Skip to content

Commit d53db20

Browse files
committed
feat(sdk): support thinking controll for model invoking
1 parent a4fac1b commit d53db20

File tree

3 files changed

+84
-44
lines changed

3 files changed

+84
-44
lines changed

packages/ui-tars/sdk/src/GUIAgent.ts

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ export class GUIAgent<T extends Operator> extends BaseGUIAgent<
108108

109109
let loopCnt = 0;
110110
let snapshotErrCnt = 0;
111+
let totalTokens = 0;
112+
let totalTime = 0;
111113

112114
// start running agent
113115
data.status = StatusEnum.RUNNING;
@@ -240,44 +242,51 @@ export class GUIAgent<T extends Operator> extends BaseGUIAgent<
240242
scaleFactor: snapshot.scaleFactor,
241243
uiTarsVersion: this.uiTarsVersion,
242244
};
243-
const { prediction, parsedPredictions } = await asyncRetry(
244-
async (bail) => {
245-
try {
246-
const result = await model.invoke(vlmParams);
247-
return result;
248-
} catch (error: unknown) {
249-
if (
250-
error instanceof Error &&
251-
(error?.name === 'APIUserAbortError' ||
252-
error?.message?.includes('aborted'))
253-
) {
254-
bail(error as unknown as Error);
245+
const { prediction, parsedPredictions, costTime, costTokens } =
246+
await asyncRetry(
247+
async (bail) => {
248+
try {
249+
const result = await model.invoke(vlmParams);
250+
return result;
251+
} catch (error: unknown) {
252+
if (
253+
error instanceof Error &&
254+
(error?.name === 'APIUserAbortError' ||
255+
error?.message?.includes('aborted'))
256+
) {
257+
bail(error as unknown as Error);
258+
return {
259+
prediction: '',
260+
parsedPredictions: [],
261+
};
262+
}
263+
264+
Object.assign(data, {
265+
status: StatusEnum.ERROR,
266+
error: this.guiAgentErrorParser(
267+
ErrorStatusEnum.INVOKE_RETRY_ERROR,
268+
error as Error,
269+
),
270+
});
271+
255272
return {
256273
prediction: '',
257274
parsedPredictions: [],
258275
};
259276
}
277+
},
278+
{
279+
retries: retry?.model?.maxRetries ?? 0,
280+
onRetry: retry?.model?.onRetry,
281+
},
282+
);
260283

261-
Object.assign(data, {
262-
status: StatusEnum.ERROR,
263-
error: this.guiAgentErrorParser(
264-
ErrorStatusEnum.INVOKE_RETRY_ERROR,
265-
error as Error,
266-
),
267-
});
284+
totalTokens += costTokens || 0;
285+
totalTime += costTime || 0;
268286

269-
return {
270-
prediction: '',
271-
parsedPredictions: [],
272-
};
273-
}
274-
},
275-
{
276-
retries: retry?.model?.maxRetries ?? 0,
277-
onRetry: retry?.model?.onRetry,
278-
},
287+
logger.info(
288+
`[GUIAgent] consumes: >>> costTime: ${costTime}, costTokens: ${costTokens} <<<`,
279289
);
280-
281290
logger.info('[GUIAgent] Response:', prediction);
282291
logger.info(
283292
'[GUIAgent] Parsed Predictions:',
@@ -450,6 +459,10 @@ export class GUIAgent<T extends Operator> extends BaseGUIAgent<
450459
),
451460
});
452461
}
462+
463+
logger.info(
464+
`[GUIAgent] >>> totalTokens: ${totalTokens}, totalTime: ${totalTime}, loopCnt: ${loopCnt} <<<`,
465+
);
453466
}
454467
}
455468

packages/ui-tars/sdk/src/Model.ts

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
*/
55
import OpenAI, { type ClientOptions } from 'openai';
66
import {
7+
type ChatCompletionCreateParamsNonStreaming,
78
type ChatCompletionCreateParamsBase,
89
type ChatCompletionMessageParam,
910
} from 'openai/resources/chat/completions';
@@ -29,6 +30,13 @@ type OpenAIChatCompletionCreateParams = Omit<ClientOptions, 'maxRetries'> &
2930

3031
export interface UITarsModelConfig extends OpenAIChatCompletionCreateParams {}
3132

33+
export interface ThinkingVisionProModelConfig
34+
extends ChatCompletionCreateParamsNonStreaming {
35+
thinking?: {
36+
type: 'enabled' | 'disabled';
37+
};
38+
}
39+
3240
export class UITarsModel extends Model {
3341
constructor(protected readonly modelConfig: UITarsModelConfig) {
3442
super();
@@ -60,6 +68,8 @@ export class UITarsModel extends Model {
6068
},
6169
): Promise<{
6270
prediction: string;
71+
costTime?: number;
72+
costTokens?: number;
6373
}> {
6474
const { messages } = params;
6575
const {
@@ -79,25 +89,38 @@ export class UITarsModel extends Model {
7989
apiKey,
8090
});
8191

82-
const result = await openai.chat.completions.create(
83-
{
84-
model,
85-
messages,
86-
stream: false,
87-
seed: null,
88-
stop: null,
89-
frequency_penalty: null,
90-
presence_penalty: null,
91-
// custom options
92-
max_tokens,
93-
temperature,
94-
top_p,
92+
const createCompletionPrams: ChatCompletionCreateParamsNonStreaming = {
93+
model,
94+
messages,
95+
stream: false,
96+
seed: null,
97+
stop: null,
98+
frequency_penalty: null,
99+
presence_penalty: null,
100+
// custom options
101+
max_tokens,
102+
temperature,
103+
top_p,
104+
};
105+
106+
const createCompletionPramsThinkingVp: ThinkingVisionProModelConfig = {
107+
...createCompletionPrams,
108+
thinking: {
109+
type: 'disabled',
95110
},
111+
};
112+
113+
const startTime = Date.now();
114+
const result = await openai.chat.completions.create(
115+
createCompletionPramsThinkingVp,
96116
options,
97117
);
118+
const costTime = Date.now() - startTime;
98119

99120
return {
100121
prediction: result.choices?.[0]?.message?.content ?? '',
122+
costTime: costTime,
123+
costTokens: result.usage?.total_tokens ?? 0,
101124
};
102125
}
103126

@@ -152,7 +175,7 @@ export class UITarsModel extends Model {
152175
throw err;
153176
}
154177

155-
const { prediction } = result;
178+
const { prediction, costTime, costTokens } = result;
156179

157180
try {
158181
const { parsed: parsedPredictions } = await actionParser({
@@ -165,6 +188,8 @@ export class UITarsModel extends Model {
165188
return {
166189
prediction,
167190
parsedPredictions,
191+
costTime,
192+
costTokens,
168193
};
169194
} catch (error) {
170195
logger?.error('[UITarsModel] error', error);

packages/ui-tars/sdk/src/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ export interface InvokeParams {
5252
export interface InvokeOutput {
5353
prediction: string;
5454
parsedPredictions: PredictionParsed[];
55+
costTime?: number;
56+
costTokens?: number;
5557
// TODO: status: StatusEnum, status should be provided by model
5658
}
5759
export abstract class Operator extends BaseOperator {

0 commit comments

Comments
 (0)