Skip to content

Commit afb499f

Browse files
committed
include tool calling UI from posit-dev/shinychat#52
1 parent 25fa285 commit afb499f

File tree

2 files changed

+70
-39
lines changed

2 files changed

+70
-39
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Imports:
2222
promises,
2323
R6,
2424
rlang,
25+
S7,
2526
shiny (>= 1.10.0),
2627
shinychat,
2728
tibble,

R/app.R

Lines changed: 69 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,23 @@
55
NULL
66

77
html_deps <- function() {
8-
htmltools::htmlDependency(
9-
"modelbot",
10-
utils::packageVersion("modelbot"),
11-
src = "www",
12-
package = "modelbot",
13-
stylesheet = "style.css"
8+
list(
9+
htmltools::htmlDependency(
10+
"modelbot",
11+
utils::packageVersion("modelbot"),
12+
src = "www",
13+
package = "modelbot",
14+
stylesheet = "style.css"
15+
),
16+
# Tool calling UI dependencies from shinychat
17+
htmltools::htmlDependency(
18+
"shinychat-tools",
19+
utils::packageVersion("shinychat"),
20+
src = "tools",
21+
package = "shinychat",
22+
stylesheet = "tool-request.css",
23+
script = "tool-request.js"
24+
)
1425
)
1526
}
1627

@@ -50,6 +61,7 @@ model_bot <- function(new_session = FALSE) {
5061
}
5162
})
5263

64+
chat <- modelbot_client(default_turns = globals$turns)
5365
restored_since_last_turn <- FALSE
5466

5567
# Restore previous chat session, if applicable
@@ -62,9 +74,15 @@ model_bot <- function(new_session = FALSE) {
6274
chat_append_message("chat", msg, chunk = FALSE)
6375
}
6476
restored_since_last_turn <- TRUE
77+
} else if (length(chat$get_turns()) > 0) {
78+
client_msgs <- shinychat::contents_shinychat(chat)
79+
if (length(client_msgs)) {
80+
for (msg in client_msgs) {
81+
chat_append_message("chat", msg, chunk = FALSE)
82+
}
83+
restored_since_last_turn <- TRUE
84+
}
6585
}
66-
67-
chat <- modelbot_client(default_turns = globals$turns)
6886
start_chat_request <- function(user_input) {
6987
prefix <- if (restored_since_last_turn) {
7088
paste0(
@@ -76,40 +94,49 @@ model_bot <- function(new_session = FALSE) {
7694
}
7795
restored_since_last_turn <<- FALSE
7896

97+
# Set up tool result callback to hide tool requests when complete
98+
clear_on_tool_result <- chat$on_tool_result(function(result) {
99+
session <- shiny::getDefaultReactiveDomain()
100+
if (is.null(session)) return()
101+
session$sendCustomMessage(
102+
"shinychat-hide-tool-request",
103+
result@request@id
104+
)
105+
})
106+
79107
stream <- save_stream_output()(
80-
chat$stream_async(paste0(prefix, user_input))
108+
chat$stream_async(paste0(prefix, user_input), stream = "content")
81109
)
82-
chat_append("chat", stream) |>
83-
promises::then(
84-
~ {
85-
if (session$isClosed()) {
86-
req(FALSE)
87-
}
88-
89-
# After each successful turn, save everything in case we need to
90-
# restore (i.e. user stops the app and restarts it)
91-
globals$turns <- chat$get_turns()
92-
save_messages(
93-
list(role = "user", content = user_input),
94-
list(role = "assistant", content = take_pending_output())
95-
)
96-
}
97-
) |>
98-
promises::finally(
99-
function(...) {
100-
tokens <- chat$get_tokens(include_system_prompt = FALSE)
101-
input <- sum(tokens$tokens[tokens$role == "user"])
102-
output <- sum(tokens$tokens[tokens$role == "assistant"])
103-
104-
cat("\n")
105-
cat(rule("Turn ", nrow(tokens) / 2), "\n", sep = "")
106-
cat("Total input tokens: ", input, "\n", sep = "")
107-
cat("Total output tokens: ", output, "\n", sep = "")
108-
cat("Total tokens: ", input + output, "\n", sep = "")
109-
cat("\n")
110-
.last_chat <<- chat
111-
}
110+
111+
p <- chat_append("chat", stream)
112+
p <- promises::then(p, function(x) {
113+
if (session$isClosed()) {
114+
req(FALSE)
115+
}
116+
117+
# After each successful turn, save everything in case we need to
118+
# restore (i.e. user stops the app and restarts it)
119+
globals$turns <- chat$get_turns()
120+
save_messages(
121+
list(role = "user", content = user_input),
122+
list(role = "assistant", content = take_pending_output())
112123
)
124+
})
125+
promises::finally(p, function(...) {
126+
clear_on_tool_result()
127+
128+
tokens <- chat$get_tokens(include_system_prompt = FALSE)
129+
input <- sum(tokens$tokens[tokens$role == "user"])
130+
output <- sum(tokens$tokens[tokens$role == "assistant"])
131+
132+
cat("\n")
133+
cat(rule("Turn ", nrow(tokens) / 2), "\n", sep = "")
134+
cat("Total input tokens: ", input, "\n", sep = "")
135+
cat("Total output tokens: ", output, "\n", sep = "")
136+
cat("Total tokens: ", input + output, "\n", sep = "")
137+
cat("\n")
138+
.last_chat <<- chat
139+
})
113140
}
114141

115142
observeEvent(input$chat_user_input, {
@@ -145,6 +172,9 @@ save_messages <- function(...) {
145172
}
146173

147174
save_output_chunk <- function(chunk) {
175+
if (S7::S7_inherits(chunk, ellmer::Content)) {
176+
chunk <- shinychat::contents_shinychat(chunk)
177+
}
148178
globals$pending_output$add(chunk)
149179
invisible()
150180
}

0 commit comments

Comments
 (0)