From e75c7544e8d2d7373ba596475043045fda578fac Mon Sep 17 00:00:00 2001 From: Artur Antonnikau Date: Sun, 11 Aug 2024 23:31:45 +0200 Subject: [PATCH 1/2] Add helpers --- examples/example_spec.rb | 91 +++++++++++++++++++ lib/rspec/llama/api_client.rb | 96 ++++++++++++++++++--- lib/rspec/llama/helpers.rb | 8 +- lib/rspec/llama/helpers/executor.rb | 58 +++++++++++++ lib/rspec/llama/helpers/resource_handler.rb | 76 ++++++++++++++++ spec/spec_helper.rb | 2 +- 6 files changed, 314 insertions(+), 17 deletions(-) create mode 100644 examples/example_spec.rb create mode 100644 lib/rspec/llama/helpers/executor.rb create mode 100644 lib/rspec/llama/helpers/resource_handler.rb diff --git a/examples/example_spec.rb b/examples/example_spec.rb new file mode 100644 index 0000000..bf83625 --- /dev/null +++ b/examples/example_spec.rb @@ -0,0 +1,91 @@ +# frozen_string_literal: true + +RSpec.configure do |config| + WebMock.allow_net_connect! + + config.api_endpoint = 'http://localhost:3000' + config.auth_endpoint = 'http://localhost:3000/users/sign_in' + config.api_creds = { user: { email: '***email***', password: '***password***' } } + + config.include Rspec::Llama::Helpers +end + +RSpec.describe 'Llama Rspec flow' do + let(:test_run_name) { 'fake_test_run_name' } + let(:model_version_name) { 'Version 1' } + + context 'when we want to set resources' do + before do + use_model('LLAMA', model_version_name) + use_prompt('pt_2') + use_assertion('ass_2') + end + + it 'should set model, prompt, assertion' do + expect(settled_model['name']).to eq('LLAMA') + expect(settled_model_version['build_name']).to eq(model_version_name) + expect(settled_prompt['name']).to eq('pt_2') + expect(settled_assertion['name']).to eq('ass_2') + end + end + + context 'when we want to create resources' do + before do + create_model('LLAMA_C', 'Vesion C') + create_prompt('pt_c', 'What is the capital of France?') + create_assertion('ass_c', 'Paris') + end + + it 'should create model, model_version, prompt, assertion' do + expect(settled_model['name']).to eq('LLAMA_C') + expect(settled_model_version['build_name']).to eq('Vesion C') + expect(settled_prompt['name']).to eq('pt_c') + expect(settled_assertion['name']).to eq('ass_c') + end + end + + describe '#execute_test_run' do + before do + use_model(model_name, model_version_name) + use_prompt('pt_2') + use_assertion('ass_2') + end + + let(:action) { execute_test_run(test_run_name, settled_prompt['id'], settled_assertion['id'], settled_model_version['id']) } + + context 'model version 1' do + let(:model_name) { 'LLAMA' } + let(:model_version_name) { 'Version 1' } + it 'should return a completed status' do + action + expect(settled_test_results.last['status']).to eq('completed') + expect(settled_assertion_results.first['state']).to eq('passed') + end + end + + context 'model version 2' do + let(:model_name) { 'LLAMA' } + let(:model_version_name) { 'Version 2' } + + it 'should return a completed status' do + action + expect(settled_test_results.last['status']).to eq('completed') + expect(settled_assertion_results.first['state']).to eq('passed') + end + + context 'create a new prompt & assertion' do + before do + create_prompt('pt_2', 'What is the capital of France?') + create_assertion('ass_2', 'Paris') + end + + it 'should return a completed status' do + action + expect(settled_test_results.last['status']).to eq('completed') + expect(settled_assertion_results.first['state']).to eq('passed') + end + end + end + end +end + diff --git a/lib/rspec/llama/api_client.rb b/lib/rspec/llama/api_client.rb index 2cacd3f..90bce2d 100644 --- a/lib/rspec/llama/api_client.rb +++ b/lib/rspec/llama/api_client.rb @@ -9,38 +9,108 @@ def initialize(endpoint, creds, auth_endpoint) @endpoint = endpoint @creds = creds @auth_endpoint = auth_endpoint + + set_token end def authenticate uri = URI(auth_endpoint) - req = Net::HTTP::Post.new(uri, 'Content-Type' => 'application/json') - req.body = creds.to_json + api_execute(:post, uri, creds) + end - response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == 'https') do |http| - http.request(req) - end + def fetch_all_models + uri = URI("#{endpoint}/models.json") + api_execute(:get, uri) + end + + def fetch_all_prompts + uri = URI("#{endpoint}/prompts.json") + api_execute(:get, uri) + end + + def fetch_all_assertions + uri = URI("#{endpoint}/assertions.json") + api_execute(:get, uri) + end + + def fetch_all_model_versions(model_id) + uri = URI("#{endpoint}/models/#{model_id}/model_versions.json") + api_execute(:get, uri) + end + + def fetch_assertion_results(test_result_id) + uri = URI("#{endpoint}/test_results/#{test_result_id}/assertion_results.json") + api_execute(:get, uri) + end + + def fetch_model(id) + uri = URI("#{endpoint}/models/#{id}.json") + api_execute(:get, uri) + end + + def fetch_test_run(test_run_id) + uri = URI("#{endpoint}/test_runs/#{test_run_id}.json") + api_execute(:get, uri) + end - response['Authorization'].gsub('Bearer ', '') + def fetch_test_model_version_run(test_model_version_run_id) + uri = URI("#{endpoint}/test_model_version_runs/#{test_model_version_run_id}.json") + api_execute(:get, uri) + end + + def fetch_test_result(test_result_id) + uri = URI("#{endpoint}/test_results/#{test_result_id}.json") + api_execute(:get, uri) + end + + def create_model(**opts) + uri = URI("#{endpoint}/models.json") + api_execute(:post, uri, opts) + end + + def create_model_version(model_id, opts) + uri = URI("#{endpoint}/models/#{model_id}/model_versions.json") + api_execute(:post, uri, opts) + end + + def create_prompt(model_id, **opts) + params = { prompt: { value: opts[:value], name: opts[:name] } } + uri = URI("#{endpoint}/prompts.json") + api_execute(:post, uri, params) + end + + def create_assertion(model_id, **opts) + uri = URI("#{endpoint}/assertions.json") + api_execute(:post, uri, opts) end def execute_test_run(params) uri = URI("#{endpoint}/test_runs.json") - req = Net::HTTP::Post.new(uri, 'Content-Type' => 'application/json') - req.body = params.to_json - req['Authorization'] = "Bearer #{token}" + api_execute(:post, uri, params) + end + + private + + def api_execute(method, uri, params = nil) + is_auth_req = uri.to_s.include?('sign_in') + + req = Net::HTTP.const_get(method.capitalize).new(uri, 'Content-Type' => 'application/json') + req['Authorization'] = "Bearer #{token}" unless is_auth_req + req.body = params.to_json if params response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == 'https') do |http| http.request(req) end - JSON.parse(response.body) + is_auth_req ? response : JSON.parse(response.body) + rescue StandardError => e + # should be implemented end - private - def token - @token ||= authenticate + @token ||= authenticate['Authorization'].gsub('Bearer ', '') end + alias_method :set_token, :token end end end diff --git a/lib/rspec/llama/helpers.rb b/lib/rspec/llama/helpers.rb index bb3b8ed..daf9faf 100644 --- a/lib/rspec/llama/helpers.rb +++ b/lib/rspec/llama/helpers.rb @@ -1,11 +1,13 @@ # frozen_string_literal: true +require_relative 'helpers/resource_handler' +require_relative 'helpers/executor' + module Rspec module Llama module Helpers - def execute_test_run(test_id) - Rspec::Llama.api_client.execute_test_run(test_id) - end + include ResourceHandler + include Executor end end end diff --git a/lib/rspec/llama/helpers/executor.rb b/lib/rspec/llama/helpers/executor.rb new file mode 100644 index 0000000..1f4c8eb --- /dev/null +++ b/lib/rspec/llama/helpers/executor.rb @@ -0,0 +1,58 @@ +# frozen_string_literal: true + +module Rspec + module Llama + module Helpers + module Executor + + def execute_test_run(name, prompt_id, assertion_id, model_version_id) + opts = prepare_test_run_data(name, prompt_id, assertion_id, model_version_id) + test_run = Rspec::Llama.api_client.execute_test_run(opts) + test_run_waiter(test_run['id']) + end + + private + + def test_run_waiter(test_run_id) + is_completed = false + is_failed = false + + spinner_thread = Thread.new do + spinner = %w[| / - \\] + while !is_completed + spinner.each do |frame| + print "\rProcessing...#{frame}" + sleep(0.2) + end + end + print "\r" and $stdout.flush + end + + loop do + test_results = fetch_test_results(test_run_id) + is_completed = !test_results.empty? && test_results.all? { |result| result['status'] == 'completed' } + fetch_assertion_results(test_results.last['id']) if is_completed + + break if is_completed + + sleep 1 + end + + spinner_thread.join + end + + def prepare_test_run_data(name, prompt_id, assertion_id, model_version_id) + opts = Hash.new { |hash, key| hash[key] = {} } + opts[:test_run][:name] = name + opts[:test_run][:prompt_id] = prompt_id + opts[:test_run][:calls] = 1 + opts[:test_run][:passing_threshold] = 0.5 + opts[:test_run][:assertion_ids] = [assertion_id] + opts[:test_run][:model_version_ids] = [model_version_id] + + opts + end + end + end + end +end diff --git a/lib/rspec/llama/helpers/resource_handler.rb b/lib/rspec/llama/helpers/resource_handler.rb new file mode 100644 index 0000000..b5f2ab9 --- /dev/null +++ b/lib/rspec/llama/helpers/resource_handler.rb @@ -0,0 +1,76 @@ +# frozen_string_literal: true + +module Rspec + module Llama + module Helpers + module ResourceHandler + + def use_model(name, version = nil) + models = Rspec::Llama.api_client.fetch_all_models + @model = models.find { |m| m['name'] == name } + use_model_version(@model['id'], version) if version + end + + def use_model_version(model_id, version) + model_versions = Rspec::Llama.api_client.fetch_all_model_versions(model_id) + @model_version = model_versions.find { |mv| mv['build_name'] == version } + end + + def use_prompt(name) + prompts = Rspec::Llama.api_client.fetch_all_prompts + @prompt = prompts.find { |p| p['name'] == name } + end + + def use_assertion(name) + assertions = Rspec::Llama.api_client.fetch_all_assertions + @assertion = assertions.find { |a| a['name'] == name } + end + + def use_test_run(test_run_id) + @test_run = Rspec::Llama.api_client.fetch_test_run(test_run_id) + end + + def use_test_model_version_run(test_model_version_run_id) + @model_version_run = Rspec::Llama.api_client.fetch_test_model_version_run(test_model_version_run_id) + end + + def create_model(name, version = nil) + @model = Rspec::Llama.api_client.create_model(name: name, url: 'http://host.docker.internal:8000/completion') + use_model_version(@model['id'], version) if version + @model_version.nil? ? create_model_version(@model['id'], version) : @model_version + end + + def create_model_version(model_id, name) + opts = { model_version: { configuration: { n_predict: 100, temperature: 0.8 }.to_json, description: '', built_on: Date.today, build_name: name } } + @model_version = Rspec::Llama.api_client.create_model_version(model_id, opts) + end + + def create_prompt(name, prompt, model = settled_model) + @prompt = Rspec::Llama.api_client.create_prompt(model['id'], name: name, value: prompt) + end + + def create_assertion(name, value, assertion_type = 'exclude', model = settled_model) + @assertion = Rspec::Llama.api_client.create_assertion(model, name: name, assertion_type: assertion_type, value: value) + end + + def fetch_test_results(test_run_id) + test_result_ids = Rspec::Llama.api_client.fetch_test_run(test_run_id)['test_result_ids'] + @test_results = test_result_ids.map { |id| Rspec::Llama.api_client.fetch_test_result(id) } + end + + def fetch_assertion_results(test_result_id) + @assertion_results = Rspec::Llama.api_client.fetch_assertion_results(test_result_id) + end + + %i[model prompt assertion model_version test_run test_results assertion_results].each do |method| + define_method("settled_#{method}") do |*args| + instance_variable_get("@#{method}").nil? ? raise("#{method} not settled") : instance_variable_get("@#{method}") + end + end + end + end + end +end + + + diff --git a/spec/spec_helper.rb b/spec/spec_helper.rb index 1a98181..c0036c8 100644 --- a/spec/spec_helper.rb +++ b/spec/spec_helper.rb @@ -7,7 +7,7 @@ RSpec.configure do |config| # Enable flags like --only-failures and --next-failure config.api_endpoint = 'https://api.example.com' - config.auth_endpoint = 'https://api.example.com/authenticate' + config.auth_endpoint = 'https://api.example.com/users/sign_in' config.api_creds = { user: { email: 'your_username', password: 'your_password' } } config.example_status_persistence_file_path = '.rspec_status' From 0e34393819e5a9dbf60e953e842f9dcc3269575e Mon Sep 17 00:00:00 2001 From: Artur Antonnikau Date: Wed, 21 Aug 2024 23:19:56 +0200 Subject: [PATCH 2/2] Update DSL --- Gemfile | 1 + Gemfile.lock | 39 ++++++++-- examples/example_spec.rb | 70 +++++++++++++---- lib/rspec/llama.rb | 1 + lib/rspec/llama/api_client.rb | 4 +- lib/rspec/llama/helpers.rb | 6 ++ lib/rspec/llama/helpers/base.rb | 46 +++++++++++ lib/rspec/llama/helpers/errors.rb | 47 ++++++++++++ lib/rspec/llama/helpers/executor.rb | 28 +++++-- lib/rspec/llama/helpers/resource_handler.rb | 85 +++++++++++---------- lib/rspec/llama/support.rb | 11 +++ lib/rspec/llama/support/test_run.rb | 26 +++++++ 12 files changed, 292 insertions(+), 72 deletions(-) create mode 100644 lib/rspec/llama/helpers/base.rb create mode 100644 lib/rspec/llama/helpers/errors.rb create mode 100644 lib/rspec/llama/support.rb create mode 100644 lib/rspec/llama/support/test_run.rb diff --git a/Gemfile b/Gemfile index 147e6df..b6296ae 100644 --- a/Gemfile +++ b/Gemfile @@ -7,6 +7,7 @@ gemspec gem 'rake', '~> 12.0' gem 'rspec', '~> 3.0' +gem 'activesupport' group :development, :test do gem 'rubocop' gem 'webmock', '~> 3.0' diff --git a/Gemfile.lock b/Gemfile.lock index 22c8d74..7ba0105 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -9,32 +9,51 @@ PATH GEM remote: https://rubygems.org/ specs: + activesupport (7.2.0) + base64 + bigdecimal + concurrent-ruby (~> 1.0, >= 1.3.1) + connection_pool (>= 2.2.5) + drb + i18n (>= 1.6, < 2) + logger (>= 1.4.2) + minitest (>= 5.1) + securerandom (>= 0.3) + tzinfo (~> 2.0, >= 2.0.5) addressable (2.8.7) public_suffix (>= 2.0.2, < 7.0) ast (2.4.2) + base64 (0.2.0) bigdecimal (3.1.8) + concurrent-ruby (1.3.4) + connection_pool (2.4.1) crack (1.0.0) bigdecimal rexml diff-lcs (1.5.1) - hashdiff (1.1.0) + drb (2.2.1) + hashdiff (1.1.1) + i18n (1.14.5) + concurrent-ruby (~> 1.0) json (2.7.2) language_server-protocol (3.17.0.3) + logger (1.6.0) + minitest (5.25.1) net-http (0.1.1) net-protocol uri net-protocol (0.2.2) timeout - parallel (1.25.1) - parser (3.3.4.0) + parallel (1.26.3) + parser (3.3.4.2) ast (~> 2.4.1) racc - public_suffix (5.1.1) + public_suffix (6.0.1) racc (1.8.1) rainbow (3.1.1) rake (12.3.3) regexp_parser (2.9.2) - rexml (3.3.2) + rexml (3.3.5) strscan rspec (3.13.0) rspec-core (~> 3.13.0) @@ -42,14 +61,14 @@ GEM rspec-mocks (~> 3.13.0) rspec-core (3.13.0) rspec-support (~> 3.13.0) - rspec-expectations (3.13.1) + rspec-expectations (3.13.2) diff-lcs (>= 1.2.0, < 2.0) rspec-support (~> 3.13.0) rspec-mocks (3.13.1) diff-lcs (>= 1.2.0, < 2.0) rspec-support (~> 3.13.0) rspec-support (3.13.1) - rubocop (1.65.0) + rubocop (1.65.1) json (~> 2.3) language_server-protocol (>= 3.17.0) parallel (~> 1.10) @@ -60,11 +79,14 @@ GEM rubocop-ast (>= 1.31.1, < 2.0) ruby-progressbar (~> 1.7) unicode-display_width (>= 2.4.0, < 3.0) - rubocop-ast (1.31.3) + rubocop-ast (1.32.1) parser (>= 3.3.1.0) ruby-progressbar (1.13.0) + securerandom (0.3.1) strscan (3.1.0) timeout (0.4.1) + tzinfo (2.0.6) + concurrent-ruby (~> 1.0) unicode-display_width (2.5.0) uri (0.13.0) webmock (3.23.1) @@ -76,6 +98,7 @@ PLATFORMS ruby DEPENDENCIES + activesupport rake (~> 12.0) rspec (~> 3.0) rspec-llama! diff --git a/examples/example_spec.rb b/examples/example_spec.rb index bf83625..295ccee 100644 --- a/examples/example_spec.rb +++ b/examples/example_spec.rb @@ -11,14 +11,43 @@ end RSpec.describe 'Llama Rspec flow' do - let(:test_run_name) { 'fake_test_run_name' } let(:model_version_name) { 'Version 1' } + context 'when we want to fetch resources' do + before do + use_model('LLAMA_C', 'Vesion C') + build_prompt('pt_c', 'Is Minsk the capital of Belarus?') + build_assertion('1_ass', 'Yes') + use_prompt('pt_c') + use_assertion('1_ass') + end + + it 'should fetch model & model_version' do + expect(settled_model['name']).to eq('LLAMA_C') + expect(settled_model_version['build_name']).to eq('Vesion C') + expect(settled_prompt['name']).to eq('pt_c') + expect(settled_assertion['name']).to eq('1_ass') + end + end + + context 'when we want to create resources(prompt, assertion)' do + before do + use_model('LLAMA_C', 'Vesion C') + build_prompt('pt_s', 'Is Minsk the capital of Belarus?') + build_assertion('ass_1', 'Yes') + end + + it 'should create prompt & assertion' do + expect(settled_prompt['name']).to eq('pt_s') + expect(settled_assertion['name']).to eq('ass_1') + end + end + context 'when we want to set resources' do before do use_model('LLAMA', model_version_name) - use_prompt('pt_2') - use_assertion('ass_2') + build_prompt('pt_2', 'What is the capital of France?') + build_assertion('ass_2', 'No') end it 'should set model, prompt, assertion' do @@ -31,9 +60,9 @@ context 'when we want to create resources' do before do - create_model('LLAMA_C', 'Vesion C') - create_prompt('pt_c', 'What is the capital of France?') - create_assertion('ass_c', 'Paris') + use_model('LLAMA_C', 'Vesion C') + build_prompt('pt_c', 'What is the capital of France?') + build_assertion('ass_c', 'Paris') end it 'should create model, model_version, prompt, assertion' do @@ -51,15 +80,14 @@ use_assertion('ass_2') end - let(:action) { execute_test_run(test_run_name, settled_prompt['id'], settled_assertion['id'], settled_model_version['id']) } + let(:action) { execute_test_run } context 'model version 1' do let(:model_name) { 'LLAMA' } - let(:model_version_name) { 'Version 1' } + let(:model_version_name) { 'Version 2' } it 'should return a completed status' do action - expect(settled_test_results.last['status']).to eq('completed') - expect(settled_assertion_results.first['state']).to eq('passed') + expect(test_run).to be_successful end end @@ -69,23 +97,33 @@ it 'should return a completed status' do action - expect(settled_test_results.last['status']).to eq('completed') - expect(settled_assertion_results.first['state']).to eq('passed') + expect(test_run).to be_successful end context 'create a new prompt & assertion' do before do - create_prompt('pt_2', 'What is the capital of France?') - create_assertion('ass_2', 'Paris') + build_prompt('pt_2', 'what is the capital of france?') + build_assertion('ass_2', 'paris') + end + + it 'should return a completed status' do + action + expect(test_run).to be_failed + end + end + + context 'with more manual testing' do + before do + build_prompt('pt_2', 'what is the capital of france?') + build_assertion('ass_2', 'paris') end it 'should return a completed status' do action expect(settled_test_results.last['status']).to eq('completed') - expect(settled_assertion_results.first['state']).to eq('passed') + expect(settled_assertion_results.first['state']).to eq('failed') end end end end end - diff --git a/lib/rspec/llama.rb b/lib/rspec/llama.rb index ba97b39..8087ec5 100644 --- a/lib/rspec/llama.rb +++ b/lib/rspec/llama.rb @@ -8,3 +8,4 @@ require_relative 'llama/configuration' require_relative 'llama/version' require_relative 'llama/helpers' +require_relative 'llama/support/test_run' diff --git a/lib/rspec/llama/api_client.rb b/lib/rspec/llama/api_client.rb index 90bce2d..4f2e12d 100644 --- a/lib/rspec/llama/api_client.rb +++ b/lib/rspec/llama/api_client.rb @@ -73,13 +73,13 @@ def create_model_version(model_id, opts) api_execute(:post, uri, opts) end - def create_prompt(model_id, **opts) + def create_prompt(opts) params = { prompt: { value: opts[:value], name: opts[:name] } } uri = URI("#{endpoint}/prompts.json") api_execute(:post, uri, params) end - def create_assertion(model_id, **opts) + def create_assertion(opts) uri = URI("#{endpoint}/assertions.json") api_execute(:post, uri, opts) end diff --git a/lib/rspec/llama/helpers.rb b/lib/rspec/llama/helpers.rb index daf9faf..133d31a 100644 --- a/lib/rspec/llama/helpers.rb +++ b/lib/rspec/llama/helpers.rb @@ -1,13 +1,19 @@ # frozen_string_literal: true +require "active_support/all" + require_relative 'helpers/resource_handler' require_relative 'helpers/executor' +require_relative 'helpers/base' +require_relative 'helpers/errors' module Rspec module Llama module Helpers + include Base include ResourceHandler include Executor + include Errors end end end diff --git a/lib/rspec/llama/helpers/base.rb b/lib/rspec/llama/helpers/base.rb new file mode 100644 index 0000000..8493c0c --- /dev/null +++ b/lib/rspec/llama/helpers/base.rb @@ -0,0 +1,46 @@ +# frozen_string_literal: true + +module Rspec + module Llama + module Helpers + module Base + + def fetch_resource_with_error_handling(resource, *args, &block) + key = (resource == :model_version ? 'build_name' : 'name') + name = args.delete_at(0) + + resource_obj = if block_given? + yield + else + resources = Rspec::Llama.api_client.public_send("fetch_all_#{resource.to_s.pluralize}", *args) + resources.find { |r| r[key] == name } + end + raise Errors::ResourceNotFound.new(resource, name) unless resource_obj + + instance_variable_set("@#{resource}", resource_obj) + rescue StandardError => e + raise e + end + + def create_resource_with_error_handling(resource, args) + resource_obj = Rspec::Llama.api_client.public_send("create_#{resource}", args) || {} + resource_obj = if resource_obj['name'].is_a?(Array) + resource_obj['name'].last + else + resource_obj + end + + raise Errors::ResourceFailedToCreate.new(resource, args[:name]) if resource_obj.blank? + raise Errors::ResourceAlreadyCreated.new(resource, args[:name]) if resource_obj == "has already been taken" + + instance_variable_set("@#{resource}", resource_obj) + rescue Errors::ResourceAlreadyCreated => e + fetch_resource_with_error_handling(e.resource, e.name) + rescue StandardError => e + raise e + end + end + end + end +end + diff --git a/lib/rspec/llama/helpers/errors.rb b/lib/rspec/llama/helpers/errors.rb new file mode 100644 index 0000000..2ff6c4a --- /dev/null +++ b/lib/rspec/llama/helpers/errors.rb @@ -0,0 +1,47 @@ +# frozen_string_literal: true + +module Rspec + module Llama + module Helpers + module Errors + + class BaseResourceError < StandardError + attr_reader :resource, :name + def initialize(resource, name, message) + @resource = resource + @name = name + super(message) + end + end + + class ResourceAlreadyCreated < BaseResourceError + def initialize(resource, name) + msg = "#{resource.to_s.capitalize} with name #{name} has already been created" + super(resource, name, msg) + end + end + + class ResourceNotFound < BaseResourceError + def initialize(resource, name) + msg = "#{resource.to_s.capitalize} with name #{name} not found" + super(resource, name, msg) + end + end + + class ResourceFailedToCreate < BaseResourceError + def initialize(resource, name) + msg = "Failed to create #{resource.to_s.capitalize} with name #{name}" + super(resource, name, msg) + end + end + + class TestRunExecutionError < StandardError + def initialize(test_run) + msg = test_run.map { |k, v| "#{k.singularize} #{v.last}" }.join(', ') + super(msg) + end + end + end + end + end +end diff --git a/lib/rspec/llama/helpers/executor.rb b/lib/rspec/llama/helpers/executor.rb index 1f4c8eb..f8891a0 100644 --- a/lib/rspec/llama/helpers/executor.rb +++ b/lib/rspec/llama/helpers/executor.rb @@ -5,9 +5,15 @@ module Llama module Helpers module Executor - def execute_test_run(name, prompt_id, assertion_id, model_version_id) - opts = prepare_test_run_data(name, prompt_id, assertion_id, model_version_id) + def execute_test_run(prompt_id: nil, assertion_id: nil, model_version_id: nil) + prompt_id ||= settled_prompt['id'] + assertion_id ||= settled_assertion['id'] + model_version_id ||= settled_model_version['id'] + + opts = prepare_test_run_data(prompt_id, assertion_id, model_version_id) test_run = Rspec::Llama.api_client.execute_test_run(opts) + validate_test_run!(test_run) + test_run_waiter(test_run['id']) end @@ -21,7 +27,7 @@ def test_run_waiter(test_run_id) spinner = %w[| / - \\] while !is_completed spinner.each do |frame| - print "\rProcessing...#{frame}" + print "\r Processing...#{frame}" sleep(0.2) end end @@ -29,9 +35,9 @@ def test_run_waiter(test_run_id) end loop do - test_results = fetch_test_results(test_run_id) + test_results = retrieve_test_results(test_run_id) is_completed = !test_results.empty? && test_results.all? { |result| result['status'] == 'completed' } - fetch_assertion_results(test_results.last['id']) if is_completed + retrieve_assertion_results(test_results.last['id']) if is_completed break if is_completed @@ -41,9 +47,17 @@ def test_run_waiter(test_run_id) spinner_thread.join end - def prepare_test_run_data(name, prompt_id, assertion_id, model_version_id) + def test_run_name + "rspec_#{Time.now.strftime('%Y_%m_%d_%H_%M_%S')}" + end + + def validate_test_run!(test_run) + raise Errors::TestRunExecutionError.new(test_run) unless test_run&.[]('id') + end + + def prepare_test_run_data(prompt_id, assertion_id, model_version_id) opts = Hash.new { |hash, key| hash[key] = {} } - opts[:test_run][:name] = name + opts[:test_run][:name] = test_run_name opts[:test_run][:prompt_id] = prompt_id opts[:test_run][:calls] = 1 opts[:test_run][:passing_threshold] = 0.5 diff --git a/lib/rspec/llama/helpers/resource_handler.rb b/lib/rspec/llama/helpers/resource_handler.rb index b5f2ab9..2f6a315 100644 --- a/lib/rspec/llama/helpers/resource_handler.rb +++ b/lib/rspec/llama/helpers/resource_handler.rb @@ -5,72 +5,79 @@ module Llama module Helpers module ResourceHandler - def use_model(name, version = nil) - models = Rspec::Llama.api_client.fetch_all_models - @model = models.find { |m| m['name'] == name } - use_model_version(@model['id'], version) if version - end - - def use_model_version(model_id, version) - model_versions = Rspec::Llama.api_client.fetch_all_model_versions(model_id) - @model_version = model_versions.find { |mv| mv['build_name'] == version } + def use_model(model_name, version_name) + fetch_model(model_name) + fetch_model_version(version_name, settled_model['id']) end def use_prompt(name) - prompts = Rspec::Llama.api_client.fetch_all_prompts - @prompt = prompts.find { |p| p['name'] == name } + fetch_prompt(name) end def use_assertion(name) - assertions = Rspec::Llama.api_client.fetch_all_assertions - @assertion = assertions.find { |a| a['name'] == name } + fetch_assertion(name) end - def use_test_run(test_run_id) - @test_run = Rspec::Llama.api_client.fetch_test_run(test_run_id) + def build_prompt(name, prompt) + opts = { name: name, value: prompt, model_id: settled_model['id'] } + create_prompt(opts) end - def use_test_model_version_run(test_model_version_run_id) - @model_version_run = Rspec::Llama.api_client.fetch_test_model_version_run(test_model_version_run_id) + def build_assertion(name, value, assertion_type = 'exclude_all') + opts = { name: name, value: value, assertion_type: assertion_type, model_id: settled_model['id'] } + create_assertion(opts) end - def create_model(name, version = nil) - @model = Rspec::Llama.api_client.create_model(name: name, url: 'http://host.docker.internal:8000/completion') - use_model_version(@model['id'], version) if version - @model_version.nil? ? create_model_version(@model['id'], version) : @model_version - end - - def create_model_version(model_id, name) - opts = { model_version: { configuration: { n_predict: 100, temperature: 0.8 }.to_json, description: '', built_on: Date.today, build_name: name } } - @model_version = Rspec::Llama.api_client.create_model_version(model_id, opts) + def retrieve_test_run(test_run_id) + test_run = fetch_test_run(test_run_id) do + Rspec::Llama.api_client.fetch_test_run(test_run_id) + end end - def create_prompt(name, prompt, model = settled_model) - @prompt = Rspec::Llama.api_client.create_prompt(model['id'], name: name, value: prompt) + def retrieve_test_model_version_run(test_model_version_run_id) + fetch_test_model_version_run(test_model_version_run_id) do + Rspec::Llama.api_client.fetch_test_model_version_run(test_model_version_run_id) + end end - def create_assertion(name, value, assertion_type = 'exclude', model = settled_model) - @assertion = Rspec::Llama.api_client.create_assertion(model, name: name, assertion_type: assertion_type, value: value) + def retrieve_test_results(test_run_id) + fetch_test_results(test_run_id) do + test_result_ids = retrieve_test_run(test_run_id)['test_result_ids'] + test_result_ids.map { |id| Rspec::Llama.api_client.fetch_test_result(id) } + end end - def fetch_test_results(test_run_id) - test_result_ids = Rspec::Llama.api_client.fetch_test_run(test_run_id)['test_result_ids'] - @test_results = test_result_ids.map { |id| Rspec::Llama.api_client.fetch_test_result(id) } + def retrieve_assertion_results(test_result_id) + fetch_assertion_results(test_result_id) do + Rspec::Llama.api_client.fetch_assertion_results(test_result_id) + end end - def fetch_assertion_results(test_result_id) - @assertion_results = Rspec::Llama.api_client.fetch_assertion_results(test_result_id) + def test_run + Support::TestRun.new(settled_test_run['id'], settled_test_run['name'], settled_test_run['test_result_ids']) end %i[model prompt assertion model_version test_run test_results assertion_results].each do |method| - define_method("settled_#{method}") do |*args| + define_method("settled_#{method}") do |*_args| instance_variable_get("@#{method}").nil? ? raise("#{method} not settled") : instance_variable_get("@#{method}") end end + + private + + %i[model model_version prompt assertion test_run test_model_version_run test_results + assertion_results].each do |method| + define_method("fetch_#{method}") do |*args, &block| + fetch_resource_with_error_handling(method, *args, &block) + end + end + + %i[prompt assertion].each do |method| + define_method("create_#{method}") do |args| + create_resource_with_error_handling(method, args) + end + end end end end end - - - diff --git a/lib/rspec/llama/support.rb b/lib/rspec/llama/support.rb new file mode 100644 index 0000000..504f311 --- /dev/null +++ b/lib/rspec/llama/support.rb @@ -0,0 +1,11 @@ +# frozen_string_literal: true + +require_relative 'support/test_run' + +module Rspec + module Llama + module Support + include TestRun + end + end +end diff --git a/lib/rspec/llama/support/test_run.rb b/lib/rspec/llama/support/test_run.rb new file mode 100644 index 0000000..a889e1e --- /dev/null +++ b/lib/rspec/llama/support/test_run.rb @@ -0,0 +1,26 @@ +# frozen_string_literal: true + +module Rspec + module Llama + module Support + TestRun = Struct.new(:id, :name, :test_result_ids) do + def latest_test_result + test_result_ids.map { |id| Rspec::Llama.api_client.fetch_test_result(id) }.last + end + + def latest_assertion_result + Rspec::Llama.api_client.fetch_assertion_results(latest_test_result['id']).last + end + + def successful? + latest_assertion_result['state'] == 'passed' + end + + def failed? + latest_assertion_result['state'] == 'failed' + end + end + end + end +end +