246 changes: 240 additions & 6 deletions ITokenizer.cs
@@ -3,6 +3,8 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.RegularExpressions;
using System.Threading.Tasks;

namespace LLAMA;
@@ -33,6 +35,19 @@ public override NormalizedString Normalize(string original)
}
}

public class TikTokenNormalizer : Normalizer
{
public override NormalizedString Normalize(string original)
{
// replace newline with Ċ
var normalized = original.Replace(Environment.NewLine, "Ċ");
// replace whitespace with Ġ
normalized = normalized.Replace(' ', 'Ġ');

return new NormalizedString(original, normalized, null, isOneToOneMapping: true);
}
}
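
A note on the two marker characters: Ġ and Ċ follow the byte-level BPE convention used by GPT-2- and tiktoken-style vocabularies, where Ġ stands for a space and Ċ for a newline. A minimal sketch of what this normalizer produces, assuming the class above is in scope:

var normalizer = new TikTokenNormalizer();
var result = normalizer.Normalize("hello world" + Environment.NewLine);
// the normalized text is "helloĠworldĊ"; note that Environment.NewLine is
// "\r\n" on Windows, so a bare "\n" in the input would pass through unchanged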

public class PreTokenizer : Microsoft.ML.Tokenizers.PreTokenizer
{
public override IReadOnlyList<Split> PreTokenize(string sentence)
@@ -43,22 +58,52 @@ public override IReadOnlyList<Split> PreTokenize(string sentence)
}
}

public class SplitPreTokenizer : Microsoft.ML.Tokenizers.PreTokenizer
{
private readonly string _pattern;

public SplitPreTokenizer(string pattern)
{
this._pattern = pattern;
}

public override IReadOnlyList<Split> PreTokenize(string? sentence)
{
if (sentence == null)
{
return [];
}

List<Split> list = new List<Split>();
foreach (Match item in Regex.Matches(sentence, _pattern))
{
list.Add(new Split(item.Value, (item.Index, item.Index + item.Length)));
}

return list;
}
}
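
The splitter above mirrors tiktoken-style pre-tokenization: a single regex decides the chunk boundaries that BPE merges may not cross. A usage sketch with a deliberately simple, hypothetical pattern (the pattern Llama 3 actually ships is far more elaborate):

var pre = new SplitPreTokenizer(@"\S+");
var splits = pre.PreTokenize("one two");
// yields two splits: "one" at offsets (0, 3) and "two" at offsets (4, 7)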

public class TokenizeDecoder : Microsoft.ML.Tokenizers.TokenizerDecoder
{
- private const char spaceReplacement = '▁';
+ private char spaceReplacement = '▁';
+ private char newlineReplacement = 'Ċ';
private string bos = "<s>";
private string eos = "</s>";

- public TokenizeDecoder(string bos = "<s>", string eos = "</s>")
+ public TokenizeDecoder(string bos = "<s>", string eos = "</s>", char spaceReplacement = '▁', char newlineReplacement = 'Ċ')
{
this.bos = bos;
this.eos = eos;
this.spaceReplacement = spaceReplacement;
this.newlineReplacement = newlineReplacement;
}

public override string Decode(IEnumerable<string> tokens)
{
var str = string.Join("", tokens);
str = str.Replace(spaceReplacement, ' ');
str = str.Replace(newlineReplacement.ToString(), Environment.NewLine);

if (str.StartsWith(bos))
{
@@ -74,12 +119,12 @@ public override string Decode(IEnumerable<string> tokens)
}
}
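
To make the decoder concrete, here is a sketch of a round trip with the default SentencePiece-style markers (the bos/eos stripping sits in the collapsed lines above; this assumes it removes both markers):

var decoder = new TokenizeDecoder();
var text = decoder.Decode(new[] { "<s>", "▁Hello", "▁world", "</s>" });
// tokens join to "<s>▁Hello▁world</s>", ▁ maps back to a space, and the
// bos/eos strings are stripped, leaving " Hello world" (the leading space
// is later removed by the tokenizer's addPrecedingSpace handling)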

- public class BPETokenizer : ITokenizer
+ public class LLama2Tokenizer : ITokenizer
{
private Tokenizer tokenizer;
private bool addPrecedingSpace;

- public BPETokenizer(string vocabPath, string mergesPath, bool addPrecedingSpace = true, int padToken = -1, int startToken = 1, int endToken = 2)
+ public LLama2Tokenizer(string vocabPath, string mergesPath, bool addPrecedingSpace = true, int padToken = -1, int startToken = 1, int endToken = 2)
{
this.BosId = startToken;
this.EosId = endToken;
@@ -91,13 +136,75 @@ public BPETokenizer(string vocabPath, string mergesPath, bool addPrecedingSpace
this.tokenizer.Decoder = decoder;
}

- public static BPETokenizer FromPretrained(
public LLama2Tokenizer(Dictionary<string, int> vocab, List<string> merges, bool addPrecedingSpace = true, int padToken = -1, int startToken = 1, int endToken = 2)
{
this.BosId = startToken;
this.EosId = endToken;
this.addPrecedingSpace = addPrecedingSpace;
this.PadId = padToken;
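// Bpe below is constructed from file paths, so the in-memory vocab/merges
// are first written to temp files and deleted again once the model is built.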
// save vocab to vocab-temp.json
var vocabTempPath = "vocab-temp.json";
var json = JsonSerializer.Serialize(vocab);
File.WriteAllText(vocabTempPath, json);

// save merges to merges-temp.txt
var mergesTempPath = "merges-temp.txt";
File.WriteAllLines(mergesTempPath, merges);

var bpe = new Bpe(vocabTempPath, mergesTempPath);

this.tokenizer = new Tokenizer(bpe, preTokenizer: new PreTokenizer(), normalizer: new Norm());
var decoder = new TokenizeDecoder(this.tokenizer.Model.IdToToken(this.BosId)!, this.tokenizer.Model.IdToToken(this.EosId)!);
this.tokenizer.Decoder = decoder;

// delete temp files
File.Delete(vocabTempPath);
File.Delete(mergesTempPath);
}

public static LLama2Tokenizer FromPretrained(
string folder,
string tokenizerJsonPath = "tokenizer.json"
)
{
- throw new NotImplementedException();
tokenizerJsonPath = Path.Combine(folder, tokenizerJsonPath);
var json = File.ReadAllText(tokenizerJsonPath);
var jsonDocument = JsonDocument.Parse(json);
// vocab: .model.vocab
var vocabNode = jsonDocument.RootElement.GetProperty("model").GetProperty("vocab");

// to Dictionary<string, int>
var vocab = new Dictionary<string, int>();
foreach (var item in vocabNode.EnumerateObject())
{
vocab[item.Name] = item.Value.GetInt32();
}

// added tokens: .added_tokens
var addedTokensNode = jsonDocument.RootElement.GetProperty("added_tokens");
foreach (var item in addedTokensNode.EnumerateArray())
{
// get id from item.id
var id = item.GetProperty("id").GetInt32();
var content = item.GetProperty("content").GetString()!;
vocab[content] = id;
}

// merges: .model.merges
var mergesNode = jsonDocument.RootElement.GetProperty("model").GetProperty("merges");
// merges: List<string>
var merges = new List<string>();
foreach (var item in mergesNode.EnumerateArray())
{
merges.Add(item.GetString()!);
}

var startToken = vocab["<s>"]; // Llama 2's tokenizer.json registers <s>/</s> as BOS/EOS
var endToken = vocab["</s>"]; // (<|begin_of_text|>/<|end_of_text|> are Llama 3 names)

return new LLama2Tokenizer(vocab, merges, startToken: startToken, endToken: endToken);
}

public int VocabSize => this.tokenizer.Model.GetVocabSize();

public int PadId { get; }
@@ -138,3 +245,130 @@ public int[] Encode(string input, bool bos, bool eos)
return tokens;
}
}
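
A usage sketch for the class above; the folder path is hypothetical and only needs to contain a Hugging Face-style tokenizer.json:

var tokenizer = LLama2Tokenizer.FromPretrained("path/to/llama2");
var ids = tokenizer.Encode("Hello world", bos: true, eos: false);
var text = tokenizer.Decode(ids); // "Hello world"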

public class LLama3Tokenizer : ITokenizer
{
private Tokenizer tokenizer;
private bool addPrecedingSpace;

public LLama3Tokenizer(string vocabPath, string mergesPath, bool addPrecedingSpace = false, int padToken = -1, int startToken = 1, int endToken = 2)
{
this.BosId = startToken;
this.EosId = endToken;
this.addPrecedingSpace = addPrecedingSpace;
this.PadId = padToken;
var bpe = new Bpe(vocabPath, mergesPath);
this.tokenizer = new Tokenizer(bpe, preTokenizer: new PreTokenizer(), normalizer: new TikTokenNormalizer());
var decoder = new TokenizeDecoder(this.tokenizer.Model.IdToToken(this.BosId)!, this.tokenizer.Model.IdToToken(this.EosId)!, 'Ġ');
this.tokenizer.Decoder = decoder;
}

public LLama3Tokenizer(Dictionary<string, int> vocab, List<string> merges, bool addPrecedingSpace = false, int padToken = -1, int startToken = 1, int endToken = 2)
{
this.BosId = startToken;
this.EosId = endToken;
this.addPrecedingSpace = addPrecedingSpace;
this.PadId = padToken;
// save vocab to vocab-temp.json
var vocabTempPath = "vocab-temp.json";
var json = JsonSerializer.Serialize(vocab);
File.WriteAllText(vocabTempPath, json);

// save merges to merges-temp.txt
var mergesTempPath = "merges-temp.txt";
File.WriteAllLines(mergesTempPath, merges);

var bpe = new Bpe(vocabTempPath, mergesTempPath);
this.tokenizer = new Tokenizer(bpe, preTokenizer: new PreTokenizer(), normalizer: new TikTokenNormalizer());
var decoder = new TokenizeDecoder(this.tokenizer.Model.IdToToken(this.BosId)!, this.tokenizer.Model.IdToToken(this.EosId)!, 'Ġ');
this.tokenizer.Decoder = decoder;

// delete temp files
File.Delete(vocabTempPath);
File.Delete(mergesTempPath);
}

public static LLama3Tokenizer FromPretrained(
string folder,
string tokenizerJsonPath = "tokenizer.json"
)
{
tokenizerJsonPath = Path.Combine(folder, tokenizerJsonPath);
var json = File.ReadAllText(tokenizerJsonPath);
var jsonDocument = JsonDocument.Parse(json);
// vocab: .model.vocab
var vocabNode = jsonDocument.RootElement.GetProperty("model").GetProperty("vocab");

// to Dictionary<string, int>
var vocab = new Dictionary<string, int>();
foreach (var item in vocabNode.EnumerateObject())
{
vocab[item.Name] = item.Value.GetInt32();
}

// added tokens: .added_tokens
var addedTokensNode = jsonDocument.RootElement.GetProperty("added_tokens");
foreach (var item in addedTokensNode.EnumerateArray())
{
// get id from item.id
var id = item.GetProperty("id").GetInt32();
var content = item.GetProperty("content").GetString()!;
vocab[content] = id;
}

// merges: .model.merges
var mergesNode = jsonDocument.RootElement.GetProperty("model").GetProperty("merges");
// merges: List<string>
var merges = new List<string>();
foreach (var item in mergesNode.EnumerateArray())
{
merges.Add(item.GetString()!);
}

var startToken = vocab["<|begin_of_text|>"];
var endToken = vocab["<|end_of_text|>"];

return new LLama3Tokenizer(vocab, merges, startToken: startToken, endToken: endToken);
}

public int VocabSize => this.tokenizer.Model.GetVocabSize();

public int PadId { get; }

public int BosId { get; }

public int EosId { get; }

public string Decode(int[] input)
{
var str = this.tokenizer.Decode(input) ?? throw new Exception("Failed to decode");
if (this.addPrecedingSpace)
{
str = str.TrimStart();
}

return str;
}

public int[] Encode(string input, bool bos, bool eos)
{
if (this.addPrecedingSpace)
{
input = " " + input;
}
var tokens = this.tokenizer.Encode(input).Ids.ToArray();
if (bos)
{
tokens = new int[] { this.BosId }.Concat(tokens).ToArray();
}
if (eos)
{
tokens = tokens.Concat(new int[] { this.EosId }).ToArray();
}

Console.WriteLine($"tokens: {string.Join(",", tokens)}");

return tokens;
}
}
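
The Llama 3 variant differs from the Llama 2 one in two visible ways: it normalizes with TikTokenNormalizer instead of the SentencePiece-style Norm, and it passes 'Ġ' as the decoder's space replacement so decoding maps the byte-level markers back. A sketch of the symmetry:

// encode: "Hello world" --normalize--> "HelloĠworld" --BPE--> token ids
// decode: token ids --> "HelloĠworld" --decoder--> "Hello world"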

4 changes: 0 additions & 4 deletions LLAMA.cs
@@ -50,7 +50,6 @@ public static LLaMA Build(
modelArgs.VocabSize = tokenizer.VocabSize;
modelArgs.MaxSeqLen = maxSeqLen;
modelArgs.MaxBatchSize = maxBatchSize;
- torch.set_default_dtype(torch.bfloat16);
// print model args
var modelArgsJson = JsonSerializer.Serialize(modelArgs, new JsonSerializerOptions { WriteIndented = true });
Console.WriteLine($"modelArgs: {modelArgsJson}");
@@ -132,9 +131,6 @@ public static LLaMA Build(
nextToken = nextToken.reshape(-1);
// only replace token if prompt has already been generated
nextToken = torch.where(inputTextMask[.., curPos], tokens[.., curPos], nextToken);

- // print nextToken
- Console.WriteLine($"nextToken: {string.Join(",", nextToken.data<long>())}");
tokens[.., curPos] = nextToken;
if (logProbs)
{
3 changes: 1 addition & 2 deletions Model.cs
@@ -207,7 +207,7 @@ public FeedForward(ModelArgs args)
{
var hiddenDim = args.Dim * 4;
hiddenDim = 2 * hiddenDim / 3;
- hiddenDim = args.FFNDimMultiplier.HasValue ? (int)args.FFNDimMultiplier.Value * hiddenDim : hiddenDim;
+ hiddenDim = args.FFNDimMultiplier.HasValue ? (int)(args.FFNDimMultiplier.Value * hiddenDim) : hiddenDim;

// Round hidden_dim up to the next multiple of the multiple_of parameter
hiddenDim = args.MultipleOf * ((hiddenDim + args.MultipleOf - 1) / args.MultipleOf);
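
The one-character fix above matters: in the old line the cast bound to FFNDimMultiplier alone, so a multiplier such as 1.3 truncated to 1 and was silently ignored. A worked example with hypothetical Llama-style settings (Dim = 4096, FFNDimMultiplier = 1.3, MultipleOf = 256):

// hiddenDim = 2 * (4 * 4096) / 3 = 10922
// old: (int)1.3 * 10922   = 10922 (multiplier lost), rounds up to 11008
// new: (int)(1.3 * 10922) = 14198 (multiplier applied), rounds up to 14336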
@@ -313,7 +313,6 @@ public override Tensor forward(Tensor tokens, int startPos)
var h = this.tok_embeddings.forward(tokens);
var freqsComplex = this.freqs_compex[startPos..(startPos + seqLen)].to(h.device);
Tensor? mask = null;
Console.WriteLine($"tokens shape: {string.Join(",", tokens.shape)}");

if (seqLen > 1)
{