Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 34 additions & 37 deletions LLama.Unittest/LLamaEmbedderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,43 +41,40 @@ private async Task CompareEmbeddings(string modelPath)

var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
Assert.DoesNotContain(float.NaN, spoon);

if (false)
{
//TODO: the below does not work with the new memory efficient context handling - we probably need to define Microsoft.Extensions.AI.IEmbeddingGenerator GetService interface that creates the context on the fly

var generator = (IEmbeddingGenerator<string, Embedding<float>>)embedder;
Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>());
Assert.Equal(nameof(LLamaEmbedder), generator.GetService<EmbeddingGeneratorMetadata>()?.ProviderName);
Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>()?.DefaultModelId);
Assert.NotEmpty(generator.GetService<EmbeddingGeneratorMetadata>()?.DefaultModelId!);
Assert.Same(embedder, generator.GetService<LLamaEmbedder>());
Assert.Same(generator, generator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
Assert.Null(generator.GetService<string>());

var embeddings = await generator.GenerateAsync(
[
"The cat is cute",
"The kitten is cute",
"The spoon is not real"
]);
Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));

_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");

var close = 1 - Dot(cat, kitten);
var far = 1 - Dot(cat, spoon);

_testOutputHelper.WriteLine("");
_testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}");
_testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}");

Assert.True(close < far);
}

using var context = new LLamaContext(weights, @params);
var managedEmbedder = new LLamaEmbedder(context);
IEmbeddingGenerator<string, Embedding<float>> generator = managedEmbedder;
Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>());
Assert.Equal(nameof(LLamaEmbedder), generator.GetService<EmbeddingGeneratorMetadata>()?.ProviderName);
Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>()?.DefaultModelId);
Assert.NotEmpty(generator.GetService<EmbeddingGeneratorMetadata>()?.DefaultModelId!);
Assert.Same(managedEmbedder, generator.GetService<LLamaEmbedder>());
Assert.Same(generator, generator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
Assert.Null(generator.GetService<string>());

var embeddings = await generator.GenerateAsync(
[
"The cat is cute",
"The kitten is cute",
"The spoon is not real"
]);
Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));

_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");

var close = 1 - Dot(cat, kitten);
var far = 1 - Dot(cat, spoon);

_testOutputHelper.WriteLine("");
_testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}");
_testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}");

Assert.True(close < far);
}

[Fact]
Expand Down
42 changes: 19 additions & 23 deletions LLama/LLamaEmbedder.EmbeddingGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using Microsoft.Extensions.AI;

namespace LLama;
Expand All @@ -16,25 +15,27 @@ public partial class LLamaEmbedder
/// <inheritdoc />
object? IEmbeddingGenerator.GetService(Type serviceType, object? serviceKey)
{
if (serviceKey is null)
if (serviceKey is not null)
{
if (serviceType == typeof(EmbeddingGeneratorMetadata))
{
return _metadata ??= new(
nameof(LLamaEmbedder),
defaultModelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null,
defaultModelDimensions: EmbeddingSize);
}
return null;
}

if (_hasExternalContext && serviceType == typeof(EmbeddingGeneratorMetadata))
{
return _metadata ??= new(
nameof(LLamaEmbedder),
defaultModelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null,
defaultModelDimensions: EmbeddingSize);
}

if (serviceType?.IsInstanceOfType(Context) is true)
{
return Context;
}
if (_hasExternalContext && serviceType?.IsInstanceOfType(Context) is true)
{
return Context;
}

if (serviceType?.IsInstanceOfType(this) is true)
{
return this;
}
if (serviceType?.IsInstanceOfType(this) is true)
{
return this;
}

return null;
Expand All @@ -43,11 +44,6 @@ public partial class LLamaEmbedder
/// <inheritdoc />
async Task<GeneratedEmbeddings<Embedding<float>>> IEmbeddingGenerator<string, Embedding<float>>.GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)
{
if (Context.NativeHandle.PoolingType == LLamaPoolingType.None)
{
throw new NotSupportedException($"Embedding generation is not supported with {nameof(LLamaPoolingType)}.{nameof(LLamaPoolingType.None)}.");
}

GeneratedEmbeddings<Embedding<float>> results = new()
{
Usage = new() { InputTokenCount = 0 },
Expand All @@ -56,7 +52,7 @@ async Task<GeneratedEmbeddings<Embedding<float>>> IEmbeddingGenerator<string, Em
foreach (var value in values)
{
var (embeddings, tokenCount) = await GetEmbeddingsWithTokenCount(value, cancellationToken).ConfigureAwait(false);
Debug.Assert(embeddings.Count == 1, "Should be one and only one embedding when pooling is enabled.");
Debug.Assert(embeddings.Count == 1, "Should be one and only one embedding returned from LLama for a single input string.");

results.Usage.InputTokenCount += tokenCount;
results.Add(new Embedding<float>(embeddings[0]) { CreatedAt = DateTime.UtcNow });
Expand Down
70 changes: 53 additions & 17 deletions LLama/LLamaEmbedder.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Native;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using static System.Net.Mime.MediaTypeNames;

namespace LLama;

Expand All @@ -26,18 +23,26 @@ public sealed partial class LLamaEmbedder
/// <summary>
/// LLama Context
/// </summary>
/// <remarks>
/// If the context was not provided externally, the returned context will be in a disposed state.
/// </remarks>
public LLamaContext Context { get; private set; }

private LLamaWeights _weights;
private IContextParams _params;
private ILogger? _logger;
private readonly LLamaWeights? _weights;
private readonly IContextParams _params;
private readonly ILogger? _logger;
private readonly bool _hasExternalContext;

/// <summary>
/// Create a new embedder, using the given LLamaWeights
/// Create a new embedder, using the given <see cref="LLamaWeights"/>.
/// This will create and dispose a new <see cref="LLamaContext"/> for each embedding request.
/// If you want to manage the context lifetime yourself, consider using the other constructor that takes a <see cref="LLamaContext"/>.
/// </summary>
/// <param name="weights"></param>
/// <param name="params"></param>
/// <param name="logger"></param>
/// <param name="weights">weights to use for generating embeddings. The weights must be for a model that supports embeddings (i.e. it must have an encoder or a decoder, but not both).</param>
/// <param name="params">context parameters to use when creating the context</param>
/// <param name="logger">optional logger</param>
/// <exception cref="ArgumentException">raised if the provided context has batch size different from ubatch size</exception>
/// <exception cref="NotSupportedException">raised if the provided context is for an encoder-decoder model</exception>
public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
if (@params.UBatchSize != @params.BatchSize)
Expand All @@ -51,12 +56,39 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg
_weights = weights;
_params = @params;
_logger = logger;
_hasExternalContext = false;
}

/// <summary>
/// Creates a new embedder that operates on a caller-supplied <see cref="LLamaContext"/>.
/// The caller owns the context: it must stay valid for as long as this <see cref="LLamaEmbedder"/> is in use,
/// and it is NOT disposed when the embedder is disposed.
/// </summary>
/// <param name="context">context to use for generating embeddings. The underlying model must support embeddings (i.e. it must have an encoder or a decoder, but not both).</param>
/// <param name="logger">optional logger</param>
/// <exception cref="ArgumentException">raised if the provided context has batch size different from ubatch size</exception>
/// <exception cref="NotSupportedException">raised if the provided context is for an encoder-decoder model</exception>
public LLamaEmbedder(LLamaContext context, ILogger? logger = null)
{
    var contextParams = context.Params;

    // Non-causal (embedding) evaluation requires the whole batch to fit in a single micro-batch.
    if (contextParams.UBatchSize != contextParams.BatchSize)
        throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(context));

    // A model with both an encoder and a decoder cannot produce embeddings here.
    if (context.NativeHandle.ModelHandle is { HasEncoder: true, HasDecoder: true })
        throw new NotSupportedException("Computing embeddings in encoder-decoder models is not supported");

    Context = context;
    EmbeddingSize = context.EmbeddingSize;

    // Flip the externally-owned context into embedding mode.
    NativeApi.llama_set_embeddings(context.NativeHandle, true);

    _params = contextParams;
    _logger = logger;
    _hasExternalContext = true;
}

/// <inheritdoc />
public void Dispose()
{
// NOTE(review): this unconditional Dispose() immediately followed by the guarded Dispose() below
// looks like leftover text from the previous implementation (the scraped diff view interleaves
// removed and added lines) — only the guarded call should remain in the merged file; confirm.
Context.Dispose();
// Only dispose the context when this embedder created it internally; an externally-provided
// context (constructor taking LLamaContext) is the caller's responsibility. The IsClosed check
// avoids double-disposing, since the internally-created context is already disposed at the end
// of each embedding request.
if(!_hasExternalContext && !Context.NativeHandle.IsClosed)
Context.Dispose();
}

/// <summary>
Expand All @@ -72,14 +104,17 @@ public void Dispose()
public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default)
{
    // Delegate to the token-counting overload and discard the token count.
    var (embeddings, _) = await GetEmbeddingsWithTokenCount(input, cancellationToken).ConfigureAwait(false);
    return embeddings;
}


private async Task<(IReadOnlyList<float[]> Embeddings, int Tokens)> GetEmbeddingsWithTokenCount(string input, CancellationToken cancellationToken = default)
{
// Ensure the context from last time is disposed (it always should be)
if (!Context.NativeHandle.IsClosed)
Context.Dispose();
if (!_hasExternalContext)
{
if (!Context.NativeHandle.IsClosed)
Context.Dispose();

Context = _weights.CreateContext(_params, _logger);
NativeApi.llama_set_embeddings(Context.NativeHandle, true);
Context = _weights!.CreateContext(_params, _logger);
NativeApi.llama_set_embeddings(Context.NativeHandle, true);
}

// Add all of the tokens to the batch
var tokens = Context.Tokenize(input, special: true);
Expand Down Expand Up @@ -150,7 +185,8 @@ public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, Cancellati
embedding.EuclideanNormalization();
}

Context.Dispose();
if (!_hasExternalContext)
Context.Dispose();

return (results, tokens.Length);
}
Expand Down