Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 73 additions & 10 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
always_save_checkpoint = True # if True, always save a checkpoint after each eval
init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
# wandb logging
wandb_log = False # disabled by default
wandb_log = True # disabled by default
wandb_project = 'owt'
wandb_run_name = 'gpt2' # 'run' + str(time.time())

Expand All @@ -56,7 +56,7 @@
local_data = False # feeds local data
train_split_ratio = 0.8
gradient_accumulation_steps = 5 # used to simulate larger batch sizes
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
batch_size = 28 # if gradient_accumulation_steps > 1, this is the micro-batch size
block_size = 1024

# model
Expand Down Expand Up @@ -149,6 +149,7 @@ def get_dataloader(split: deeplake.Dataset, shuffle: bool = False, coef: float =
drop_last=True,
collate_fn=collate_fn)

data_dir = os.path.join('data', dataset)
if not local_data:
# split the dataset and construct dataloaders
ds = deeplake.load(dataset, read_only=True, token=token)
Expand Down Expand Up @@ -185,7 +186,6 @@ def get_batch(split: str):

else:
# poor man's data loader
data_dir = os.path.join('data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
def get_batch(split):
Expand All @@ -212,16 +212,24 @@ def get_batch(split):
iter_num = 0
best_val_loss = 1e9

# attempt to derive vocab_size from the dataset
meta_path = os.path.join(data_dir, 'meta.pkl')
if os.path.exists(meta_path):
with open(meta_path, 'rb') as f:
meta = pickle.load(f)
vocab_size = meta['vocab_size']
print(f"vocab_size = {vocab_size} (from {meta_path})")
else:
print(f"vocab_size not found in {meta_path}, using GPT-2 default of 50257")
vocab_size = 50257

# model init
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
bias=bias, vocab_size=vocab_size, dropout=dropout) # start with model_args from command line
if init_from == 'scratch':
# init a new model from scratch
print("Initializing a new model from scratch")
# determine the vocab size we'll use for from-scratch training
if meta_vocab_size is None:
print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
model_args['vocab_size'] = vocab_size
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
elif init_from == 'resume':
Expand Down Expand Up @@ -295,6 +303,44 @@ def estimate_loss():
model.train()
return out

def sample(model):
"""
Sample from a trained model
"""
import tiktoken

# -----------------------------------------------------------------------------
start = "\n" # or "<|endoftext|>" or whatever you like
num_samples = 10 # number of samples to draw
max_new_tokens = 500 # number of tokens generated in each sample
temperature = 0.8 # higher temperature (up to 1) is more random, lower (down to 0) means more greedy
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability

model.eval()

# ok let's assume gpt-2 encodings by default
print("No meta.pkl found, assuming GPT-2 encodings...")
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
decode = lambda l: enc.decode(l)

# encode the beginning of the prompt
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

# run generation
with torch.no_grad():
gen_list = []
with ctx:
for k in range(num_samples):
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
g = decode(y[0].tolist())
gen_list.append(g)
print(g)
print('---------------')
return model.train(), gen_list


# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
# 1) linear warmup for warmup_iters steps
Expand All @@ -311,8 +357,10 @@ def get_lr(it):

# logging
if wandb_log and master_process:
os.environ["WANDB_CONSOLE"]="wrap"
import wandb
wandb.init(project=wandb_project, name=wandb_run_name, config=config)
wandb_table = wandb.Table(columns=["iter_num", "gen_1","gen_2","gen_3"])

# training loop
X, Y = get_batch('train') # fetch the very first batch
Expand All @@ -332,18 +380,26 @@ def get_lr(it):
losses = estimate_loss()
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
if wandb_log:
# model, gen_list = sample(model)

# for g in gen_list[:3]:
# wandb_table.add_data(iter_num, g[0], g[1], g[2])

wandb.log({
"iter": iter_num,
"train/loss": losses['train'],
"val/loss": losses['val'],
"lr": lr,
"mfu": running_mfu*100, # convert to percentage
"train/lr": lr,
# "table/generations": wandb_table,
"train/mfu": running_mfu*100, # convert to percentage
})
if losses['val'] < best_val_loss or always_save_checkpoint:
best_val_loss = losses['val']
if iter_num > 0:
checkpoint = {
'model': raw_model.state_dict(),
# do not save the optimizer, means that resuming training won't work so well
# 'optimizer': optimizer.state_dict(),
'optimizer': optimizer.state_dict(),
'model_args': model_args,
'iter_num': iter_num,
Expand Down Expand Up @@ -390,6 +446,13 @@ def get_lr(it):
mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
# wandb logging frequency to every iteration
if iter_num % eval_interval != 0 and master_process:
wandb.log({
"iter": iter_num,
"train/loss": lossf,
"train/lr": lr
})
iter_num += 1
local_iter_num += 1

Expand Down