Skip to content

Add/formatting chat template #568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/renderer/components/Data/DatasetPreviewWithTemplate.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const fetcher = (url) =>
.then((res) => res.json())
.then((data) => data);

const DatasetTableWithTemplate = ({ datasetId, template }) => {
const DatasetTableWithTemplate = ({ datasetId, template, modelName = '' }) => {
const [pageNumber, setPageNumber] = useState(1);
const [numOfPages, setNumOfPages] = useState(1);
const [datasetLen, setDatasetLen] = useState(null);
Expand All @@ -41,6 +41,7 @@ const DatasetTableWithTemplate = ({ datasetId, template }) => {
encodeURIComponent(template),
offset,
pageSize,
modelName,
),
fetcher,
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import {
FormLabel,
Textarea,
Typography,
Switch,
Stack,
} from '@mui/joy';
import { InfoIcon } from 'lucide-react';
import { useState, useEffect } from 'react';
Expand All @@ -16,6 +18,7 @@ import DatasetTable from 'renderer/components/Data/DatasetTable';
import * as chatAPI from 'renderer/lib/transformerlab-api-sdk';
import useSWR from 'swr';
import { useDebounce } from 'use-debounce';
import { useAPI } from 'renderer/lib/transformerlab-api-sdk';

// Fetch helper: requests the given URL and resolves with the response body
// parsed as JSON.
const fetcher = async (url) => {
  const response = await fetch(url);
  return response.json();
};

Expand All @@ -28,8 +31,11 @@ function TrainingModalDataTemplatingTab({
pluginId,
}) {
const [template, setTemplate] = useState(
'Instruction: Summarize the Following\nPrompt: {{dialogue}}\nGeneration: {{summary}}'
'Instruction: Summarize the Following\nPrompt: {{dialogue}}\nGeneration: {{summary}}',
);
const [applyChatTemplate, setApplyChatTemplate] = useState(false);
const [chatTemplate, setChatTemplate] = useState('');

useEffect(() => {
//initialize the template with the saved value
if (templateData?.config?.formatting_template) {
Expand All @@ -43,12 +49,31 @@ function TrainingModalDataTemplatingTab({
chatAPI.Endpoints.Experiment.ScriptGetFile(
experimentInfo?.id,
pluginId,
'index.json'
'index.json',
),
fetcher
fetcher,
);

const {
data: chatTemplateData,
error: chatTemplateError,
isLoading: isChatTemplateLoading,
} = useAPI(
'models',
['chatTemplate'],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This throws an error for me?
Monosnap Transformer Lab 2025-06-24 10-43-50

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the API repo up to date with main?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I think I needed to merge something else too, then!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dadmobile, it should work now.

{ modelName: experimentInfo?.config?.foundation },
{ enabled: !!applyChatTemplate && !!experimentInfo?.config?.foundation },
);

useEffect(() => {
if (applyChatTemplate && chatTemplateData?.data) {
setChatTemplate(chatTemplateData.data);
}
}, [applyChatTemplate, chatTemplateData]);

const [debouncedTemplate] = useDebounce(template, 3000);
const [debouncedChatTemplate] = useDebounce(chatTemplate, 3000);

let parsedData;

try {
Expand All @@ -63,7 +88,8 @@ function TrainingModalDataTemplatingTab({
<>
<Typography level="title-lg" mt={2}>
Preview Templated Output:{' '}
{template != debouncedTemplate && (
{(template != debouncedTemplate ||
chatTemplate != debouncedChatTemplate) && (
<CircularProgress
color="neutral"
variant="plain"
Expand All @@ -82,6 +108,9 @@ function TrainingModalDataTemplatingTab({
<DatasetTableWithTemplate
datasetId={selectedDataset}
template={debouncedTemplate}
modelName={
applyChatTemplate ? experimentInfo?.config?.foundation : ''
}
/>
</>
);
Expand Down Expand Up @@ -155,6 +184,32 @@ function TrainingModalDataTemplatingTab({
);
case 'none':
return <>No data template is required for this trainer</>;
case 'missing_chat':
return (
<>
No configuration data available for this model. This may happen with
local models.
</>
);
case 'chat':
return (
<>
<FormControl>
<textarea
required
name="formatting_chat_template"
id="chat_template"
rows={10}
value={chatTemplate}
// Display-only field: the value is loaded from the model, and no onChange
// handler is wired up. Mark it readOnly so React does not treat it as a
// controlled input that is missing its change handler. To make it editable,
// replace readOnly with onChange={(e) => setChatTemplate(e.target.value)}.
readOnly
/>
<FormHelperText>
This template is fetched from the model's tokenizer config.
</FormHelperText>
</FormControl>
{selectedDataset && <PreviewSection />}
</>
);
default:
return (
<>
Expand Down Expand Up @@ -183,6 +238,13 @@ function TrainingModalDataTemplatingTab({
flexDirection: 'column',
}}
>
<Stack direction="row" spacing={2} alignItems="center" mb={2}>
<Switch
checked={applyChatTemplate}
onChange={(e) => setApplyChatTemplate(e.target.checked)}
/>
<Typography level="body-md">Apply Chat Template</Typography>
</Stack>
{parsedData?.training_template_format !== 'none' && (
<>
<Alert sx={{ mt: 1 }} color="danger">
Expand Down Expand Up @@ -211,36 +273,43 @@ function TrainingModalDataTemplatingTab({
</>
))}
</Box>

{selectedDataset && (
<FormHelperText>
Use the field names above, surrounded by
&#123;&#123;&#125;&#125; in the template below
</FormHelperText>
{!applyChatTemplate && (
<>
{selectedDataset && (
<FormHelperText>
Use the field names above, surrounded by
&#123;&#123;&#125;&#125; in the template below
</FormHelperText>
)}
<FormHelperText
sx={{ flexDirection: 'column', alignItems: 'flex-start' }}
>
The formatting template describes how the data is formatted
when passed to the trainer. Use the Jinja2 Standard String
Templating format. For example:
<br />
<span style={{}}>
Summarize the following:
<br />
Prompt: &#123;&#123;prompt&#125;&#125;
<br />
Generation: &#123;&#123;generation&#125;&#125;
</span>
</FormHelperText>
</>
)}
<FormHelperText
sx={{ flexDirection: 'column', alignItems: 'flex-start' }}
>
The formatting template describes how the data is formatted when
passed to the trainer. Use the Jinja2 Standard String Templating
format. For example:
<br />
<span style={{}}>
Summarize the following:
<br />
Prompt: &#123;&#123;prompt&#125;&#125;
<br />
Generation: &#123;&#123;generation&#125;&#125;
</span>
</FormHelperText>
</FormControl>
</Alert>
</>
)}
<Typography level="title-lg" mt={2} mb={0.5}>
Formatting Template
</Typography>
{renderTemplate(parsedData?.training_template_format)}
{/* With the chat-template switch on, show the model's chat template, or the
'missing_chat' notice when no template data came back. Otherwise dispatch on
the plugin's declared training_template_format. The format must be passed as
a value, not a quoted string, or the switch always falls to its default. */}
{applyChatTemplate
? chatTemplateData?.data
? renderTemplate('chat')
: renderTemplate('missing_chat')
: renderTemplate(parsedData?.training_template_format)}
</Box>
);
}
Expand Down
4 changes: 4 additions & 0 deletions src/renderer/lib/api-client/allEndpoints.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@
"installPeft": {
"method": "POST",
"path": "model/install_peft?peft={peft}&model_id={modelId}"
},
"chatTemplate": {
"method": "GET",
"path": "model/chat_template?model_name={modelName}"
}
},
"jobs": {
Expand Down
33 changes: 27 additions & 6 deletions src/renderer/lib/api-client/endpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Endpoints.Tasks = {
};

Endpoints.Workflows = {
ListInExperiment: (experimentId: string) =>
ListInExperiment: (experimentId: string) =>
`${API_URL()}experiment/${experimentId}/workflows/list`,
CreateEmpty: (name: string, experimentId: string) =>
`${API_URL()}experiment/${experimentId}/workflows/create_empty` +
Expand All @@ -36,16 +36,36 @@ Endpoints.Workflows = {
`${API_URL()}experiment/${experimentId}/workflows/${workflowId}/add_node?node=${node}`,
DeleteNode: (workflowId: string, nodeId: string, experimentId: string) =>
`${API_URL()}experiment/${experimentId}/workflows/${workflowId}/${nodeId}/delete_node`,
UpdateNode: (workflowId: string, nodeId: string, node: string, experimentId: string) =>
UpdateNode: (
workflowId: string,
nodeId: string,
node: string,
experimentId: string,
) =>
`${API_URL()}experiment/${experimentId}/workflows/${workflowId}/${nodeId}/update_node` +
`?node=${node}`,
EditNodeMetadata: (workflowId: string, nodeId: string, metadata: string, experimentId: string) =>
EditNodeMetadata: (
workflowId: string,
nodeId: string,
metadata: string,
experimentId: string,
) =>
`${API_URL()}experiment/${experimentId}/workflows/${workflowId}/${nodeId}/edit_node_metadata` +
`?metadata=${metadata}`,
AddEdge: (workflowId: string, from: string, to: string, experimentId: string) =>
AddEdge: (
workflowId: string,
from: string,
to: string,
experimentId: string,
) =>
`${API_URL()}experiment/${experimentId}/workflows/${workflowId}/${from}/add_edge` +
`?end_node_id=${to}`,
RemoveEdge: (workflowId: string, start_node_id: string, to: string, experimentId: string) =>
RemoveEdge: (
workflowId: string,
start_node_id: string,
to: string,
experimentId: string,
) =>
`${API_URL()}experiment/${experimentId}/workflows/${workflowId}/${start_node_id}/remove_edge` +
`?end_node_id=${to}`,
RunWorkflow: (workflowId: string, experimentId: string) =>
Expand Down Expand Up @@ -79,10 +99,11 @@ Endpoints.Dataset = {
template: string,
offset: number,
limit: number,
modelName: string = '',
) =>
`${API_URL()}data/preview_with_template?dataset_id=${datasetId}&template=${
template
}&offset=${offset}&limit=${limit}`,
}&offset=${offset}&limit=${limit}&model_name=${modelName}`,
Delete: (datasetId: string) =>
`${API_URL()}data/delete?dataset_id=${datasetId}`,
Create: (datasetId: string) => `${API_URL()}data/new?dataset_id=${datasetId}`,
Expand Down