|
| 1 | +package models |
| 2 | + |
| 3 | +import ( |
| 4 | + "cmp" |
| 5 | + "encoding/json" |
| 6 | + "log/slog" |
| 7 | + "net/http" |
| 8 | + "net/url" |
| 9 | + "os" |
| 10 | + "regexp" |
| 11 | + "strings" |
| 12 | + "unicode" |
| 13 | + |
| 14 | + "github.com/spf13/viper" |
| 15 | +) |
| 16 | + |
| 17 | +const ( |
| 18 | + ProviderLocal ModelProvider = "local" |
| 19 | + |
| 20 | + localModelsPath = "v1/models" |
| 21 | + lmStudioBetaModelsPath = "api/v0/models" |
| 22 | +) |
| 23 | + |
| 24 | +func init() { |
| 25 | + if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" { |
| 26 | + localEndpoint, err := url.Parse(endpoint) |
| 27 | + if err != nil { |
| 28 | + slog.Debug("Failed to parse local endpoint", |
| 29 | + "error", err, |
| 30 | + "endpoint", endpoint, |
| 31 | + ) |
| 32 | + return |
| 33 | + } |
| 34 | + |
| 35 | + load := func(url *url.URL, path string) []localModel { |
| 36 | + url.Path = path |
| 37 | + return listLocalModels(url.String()) |
| 38 | + } |
| 39 | + |
| 40 | + models := load(localEndpoint, lmStudioBetaModelsPath) |
| 41 | + |
| 42 | + if len(models) == 0 { |
| 43 | + models = load(localEndpoint, localModelsPath) |
| 44 | + } |
| 45 | + |
| 46 | + if len(models) == 0 { |
| 47 | + slog.Debug("No local models found", |
| 48 | + "endpoint", endpoint, |
| 49 | + ) |
| 50 | + return |
| 51 | + } |
| 52 | + |
| 53 | + loadLocalModels(models) |
| 54 | + |
| 55 | + viper.SetDefault("providers.local.apiKey", "dummy") |
| 56 | + ProviderPopularity[ProviderLocal] = 0 |
| 57 | + } |
| 58 | +} |
| 59 | + |
| 60 | +type localModelList struct { |
| 61 | + Data []localModel `json:"data"` |
| 62 | +} |
| 63 | + |
| 64 | +type localModel struct { |
| 65 | + ID string `json:"id"` |
| 66 | + Object string `json:"object"` |
| 67 | + Type string `json:"type"` |
| 68 | + Publisher string `json:"publisher"` |
| 69 | + Arch string `json:"arch"` |
| 70 | + CompatibilityType string `json:"compatibility_type"` |
| 71 | + Quantization string `json:"quantization"` |
| 72 | + State string `json:"state"` |
| 73 | + MaxContextLength int64 `json:"max_context_length"` |
| 74 | + LoadedContextLength int64 `json:"loaded_context_length"` |
| 75 | +} |
| 76 | + |
| 77 | +func listLocalModels(modelsEndpoint string) []localModel { |
| 78 | + res, err := http.Get(modelsEndpoint) |
| 79 | + if err != nil { |
| 80 | + slog.Debug("Failed to list local models", |
| 81 | + "error", err, |
| 82 | + "endpoint", modelsEndpoint, |
| 83 | + ) |
| 84 | + } |
| 85 | + defer res.Body.Close() |
| 86 | + |
| 87 | + if res.StatusCode != http.StatusOK { |
| 88 | + slog.Debug("Failed to list local models", |
| 89 | + "status", res.StatusCode, |
| 90 | + "endpoint", modelsEndpoint, |
| 91 | + ) |
| 92 | + } |
| 93 | + |
| 94 | + var modelList localModelList |
| 95 | + if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil { |
| 96 | + slog.Debug("Failed to list local models", |
| 97 | + "error", err, |
| 98 | + "endpoint", modelsEndpoint, |
| 99 | + ) |
| 100 | + } |
| 101 | + |
| 102 | + var supportedModels []localModel |
| 103 | + for _, model := range modelList.Data { |
| 104 | + if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) { |
| 105 | + if model.Object != "model" || model.Type != "llm" { |
| 106 | + slog.Debug("Skipping unsupported LMStudio model", |
| 107 | + "endpoint", modelsEndpoint, |
| 108 | + "id", model.ID, |
| 109 | + "object", model.Object, |
| 110 | + "type", model.Type, |
| 111 | + ) |
| 112 | + |
| 113 | + continue |
| 114 | + } |
| 115 | + } |
| 116 | + |
| 117 | + supportedModels = append(supportedModels, model) |
| 118 | + } |
| 119 | + |
| 120 | + return supportedModels |
| 121 | +} |
| 122 | + |
| 123 | +func loadLocalModels(models []localModel) { |
| 124 | + for i, m := range models { |
| 125 | + model := convertLocalModel(m) |
| 126 | + SupportedModels[model.ID] = model |
| 127 | + |
| 128 | + if i == 1 || m.State == "loaded" { |
| 129 | + viper.SetDefault("agents.coder.model", model.ID) |
| 130 | + viper.SetDefault("agents.summarizer.model", model.ID) |
| 131 | + viper.SetDefault("agents.task.model", model.ID) |
| 132 | + viper.SetDefault("agents.title.model", model.ID) |
| 133 | + } |
| 134 | + } |
| 135 | +} |
| 136 | + |
| 137 | +func convertLocalModel(model localModel) Model { |
| 138 | + return Model{ |
| 139 | + ID: ModelID("local." + model.ID), |
| 140 | + Name: friendlyModelName(model.ID), |
| 141 | + Provider: ProviderLocal, |
| 142 | + APIModel: model.ID, |
| 143 | + ContextWindow: cmp.Or(model.LoadedContextLength, 4096), |
| 144 | + DefaultMaxTokens: cmp.Or(model.LoadedContextLength, 4096), |
| 145 | + CanReason: true, |
| 146 | + SupportsAttachments: true, |
| 147 | + } |
| 148 | +} |
| 149 | + |
| 150 | +var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`) |
| 151 | + |
| 152 | +func friendlyModelName(modelID string) string { |
| 153 | + match := modelInfoRegex.FindStringSubmatch(modelID) |
| 154 | + if match == nil { |
| 155 | + return modelID |
| 156 | + } |
| 157 | + |
| 158 | + capitalize := func(s string) string { |
| 159 | + if s == "" { |
| 160 | + return "" |
| 161 | + } |
| 162 | + runes := []rune(s) |
| 163 | + runes[0] = unicode.ToUpper(runes[0]) |
| 164 | + return string(runes) |
| 165 | + } |
| 166 | + |
| 167 | + family := capitalize(match[1]) |
| 168 | + version := "" |
| 169 | + label := "" |
| 170 | + |
| 171 | + if len(match) > 2 && match[2] != "" { |
| 172 | + version = strings.ToUpper(match[2]) |
| 173 | + } |
| 174 | + |
| 175 | + if len(match) > 3 && match[3] != "" { |
| 176 | + label = capitalize(match[3]) |
| 177 | + } |
| 178 | + |
| 179 | + var parts []string |
| 180 | + if family != "" { |
| 181 | + parts = append(parts, family) |
| 182 | + } |
| 183 | + if version != "" { |
| 184 | + parts = append(parts, version) |
| 185 | + } |
| 186 | + if label != "" { |
| 187 | + parts = append(parts, label) |
| 188 | + } |
| 189 | + |
| 190 | + return strings.Join(parts, " ") |
| 191 | +} |
0 commit comments