Skip to content
This repository was archived by the owner on Jul 29, 2025. It is now read-only.

Commit a2524a1

Browse files
Pietjankujtimiihoxha
authored andcommitted
Add local provider
1 parent 94d5fe0 commit a2524a1

File tree

4 files changed

+202
-3
lines changed

4 files changed

+202
-3
lines changed

internal/config/config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ func validateAgent(cfg *Config, name AgentName, agent Agent) error {
526526
}
527527

528528
// Validate reasoning effort for models that support reasoning
529-
if model.CanReason && provider == models.ProviderOpenAI {
529+
if model.CanReason && provider == models.ProviderOpenAI || provider == models.ProviderLocal {
530530
if agent.ReasoningEffort == "" {
531531
// Set default reasoning effort for models that support it
532532
logging.Info("setting default reasoning effort for model that supports reasoning",

internal/llm/agent/agent.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,7 @@ func createAgentProvider(agentName config.AgentName) (provider.Provider, error)
715715
provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
716716
provider.WithMaxTokens(maxTokens),
717717
}
718-
if model.Provider == models.ProviderOpenAI && model.CanReason {
718+
if model.Provider == models.ProviderOpenAI || model.Provider == models.ProviderLocal && model.CanReason {
719719
opts = append(
720720
opts,
721721
provider.WithOpenAIOptions(

internal/llm/models/local.go

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package models
2+
3+
import (
4+
"cmp"
5+
"encoding/json"
6+
"log/slog"
7+
"net/http"
8+
"net/url"
9+
"os"
10+
"regexp"
11+
"strings"
12+
"unicode"
13+
14+
"github.com/spf13/viper"
15+
)
16+
17+
const (
18+
ProviderLocal ModelProvider = "local"
19+
20+
localModelsPath = "v1/models"
21+
lmStudioBetaModelsPath = "api/v0/models"
22+
)
23+
24+
func init() {
25+
if endpoint := os.Getenv("LOCAL_ENDPOINT"); endpoint != "" {
26+
localEndpoint, err := url.Parse(endpoint)
27+
if err != nil {
28+
slog.Debug("Failed to parse local endpoint",
29+
"error", err,
30+
"endpoint", endpoint,
31+
)
32+
return
33+
}
34+
35+
load := func(url *url.URL, path string) []localModel {
36+
url.Path = path
37+
return listLocalModels(url.String())
38+
}
39+
40+
models := load(localEndpoint, lmStudioBetaModelsPath)
41+
42+
if len(models) == 0 {
43+
models = load(localEndpoint, localModelsPath)
44+
}
45+
46+
if len(models) == 0 {
47+
slog.Debug("No local models found",
48+
"endpoint", endpoint,
49+
)
50+
return
51+
}
52+
53+
loadLocalModels(models)
54+
55+
viper.SetDefault("providers.local.apiKey", "dummy")
56+
ProviderPopularity[ProviderLocal] = 0
57+
}
58+
}
59+
60+
type localModelList struct {
61+
Data []localModel `json:"data"`
62+
}
63+
64+
type localModel struct {
65+
ID string `json:"id"`
66+
Object string `json:"object"`
67+
Type string `json:"type"`
68+
Publisher string `json:"publisher"`
69+
Arch string `json:"arch"`
70+
CompatibilityType string `json:"compatibility_type"`
71+
Quantization string `json:"quantization"`
72+
State string `json:"state"`
73+
MaxContextLength int64 `json:"max_context_length"`
74+
LoadedContextLength int64 `json:"loaded_context_length"`
75+
}
76+
77+
func listLocalModels(modelsEndpoint string) []localModel {
78+
res, err := http.Get(modelsEndpoint)
79+
if err != nil {
80+
slog.Debug("Failed to list local models",
81+
"error", err,
82+
"endpoint", modelsEndpoint,
83+
)
84+
}
85+
defer res.Body.Close()
86+
87+
if res.StatusCode != http.StatusOK {
88+
slog.Debug("Failed to list local models",
89+
"status", res.StatusCode,
90+
"endpoint", modelsEndpoint,
91+
)
92+
}
93+
94+
var modelList localModelList
95+
if err = json.NewDecoder(res.Body).Decode(&modelList); err != nil {
96+
slog.Debug("Failed to list local models",
97+
"error", err,
98+
"endpoint", modelsEndpoint,
99+
)
100+
}
101+
102+
var supportedModels []localModel
103+
for _, model := range modelList.Data {
104+
if strings.HasSuffix(modelsEndpoint, lmStudioBetaModelsPath) {
105+
if model.Object != "model" || model.Type != "llm" {
106+
slog.Debug("Skipping unsupported LMStudio model",
107+
"endpoint", modelsEndpoint,
108+
"id", model.ID,
109+
"object", model.Object,
110+
"type", model.Type,
111+
)
112+
113+
continue
114+
}
115+
}
116+
117+
supportedModels = append(supportedModels, model)
118+
}
119+
120+
return supportedModels
121+
}
122+
123+
func loadLocalModels(models []localModel) {
124+
for i, m := range models {
125+
model := convertLocalModel(m)
126+
SupportedModels[model.ID] = model
127+
128+
if i == 1 || m.State == "loaded" {
129+
viper.SetDefault("agents.coder.model", model.ID)
130+
viper.SetDefault("agents.summarizer.model", model.ID)
131+
viper.SetDefault("agents.task.model", model.ID)
132+
viper.SetDefault("agents.title.model", model.ID)
133+
}
134+
}
135+
}
136+
137+
func convertLocalModel(model localModel) Model {
138+
return Model{
139+
ID: ModelID("local." + model.ID),
140+
Name: friendlyModelName(model.ID),
141+
Provider: ProviderLocal,
142+
APIModel: model.ID,
143+
ContextWindow: cmp.Or(model.LoadedContextLength, 4096),
144+
DefaultMaxTokens: cmp.Or(model.LoadedContextLength, 4096),
145+
CanReason: true,
146+
SupportsAttachments: true,
147+
}
148+
}
149+
150+
var modelInfoRegex = regexp.MustCompile(`(?i)^([a-z0-9]+)(?:[-_]?([rv]?\d[\.\d]*))?(?:[-_]?([a-z]+))?.*`)
151+
152+
func friendlyModelName(modelID string) string {
153+
match := modelInfoRegex.FindStringSubmatch(modelID)
154+
if match == nil {
155+
return modelID
156+
}
157+
158+
capitalize := func(s string) string {
159+
if s == "" {
160+
return ""
161+
}
162+
runes := []rune(s)
163+
runes[0] = unicode.ToUpper(runes[0])
164+
return string(runes)
165+
}
166+
167+
family := capitalize(match[1])
168+
version := ""
169+
label := ""
170+
171+
if len(match) > 2 && match[2] != "" {
172+
version = strings.ToUpper(match[2])
173+
}
174+
175+
if len(match) > 3 && match[3] != "" {
176+
label = capitalize(match[3])
177+
}
178+
179+
var parts []string
180+
if family != "" {
181+
parts = append(parts, family)
182+
}
183+
if version != "" {
184+
parts = append(parts, version)
185+
}
186+
if label != "" {
187+
parts = append(parts, label)
188+
}
189+
190+
return strings.Join(parts, " ")
191+
}

internal/llm/provider/provider.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package provider
33
import (
44
"context"
55
"fmt"
6+
"os"
67

78
"github.com/opencode-ai/opencode/internal/llm/models"
89
"github.com/opencode-ai/opencode/internal/llm/tools"
@@ -145,7 +146,14 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption
145146
options: clientOptions,
146147
client: newOpenAIClient(clientOptions),
147148
}, nil
148-
149+
case models.ProviderLocal:
150+
clientOptions.openaiOptions = append(clientOptions.openaiOptions,
151+
WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")),
152+
)
153+
return &baseProvider[OpenAIClient]{
154+
options: clientOptions,
155+
client: newOpenAIClient(clientOptions),
156+
}, nil
149157
case models.ProviderMock:
150158
// TODO: implement mock client for test
151159
panic("not implemented")

0 commit comments

Comments
 (0)