|
16 | 16 | Optional,
|
17 | 17 | Sequence,
|
18 | 18 | TypeVar,
|
| 19 | + overload, |
19 | 20 | )
|
20 | 21 |
|
21 | 22 | from pydantic import BaseModel
|
@@ -176,17 +177,122 @@ def system_prompt(self, value: str | None):
|
176 | 177 | if value is not None:
|
177 | 178 | self._turns.insert(0, Turn("system", value))
|
178 | 179 |
|
179 |
| - def tokens(self) -> list[tuple[int, int] | None]: |
| 180 | + @overload |
| 181 | + def tokens(self) -> list[tuple[int, int] | None]: ... |
| 182 | + |
| 183 | + @overload |
| 184 | + def tokens( |
| 185 | + self, |
| 186 | + values: Literal["cumulative"], |
| 187 | + ) -> list[tuple[int, int] | None]: ... |
| 188 | + |
| 189 | + @overload |
| 190 | + def tokens( |
| 191 | + self, |
| 192 | + values: Literal["discrete"], |
| 193 | + ) -> list[int]: ... |
| 194 | + |
| 195 | + def tokens( |
| 196 | + self, |
| 197 | + values: Literal["cumulative", "discrete"] = "discrete", |
| 198 | + ) -> list[int] | list[tuple[int, int] | None]: |
180 | 199 | """
|
181 | 200 | Get the tokens for each turn in the chat.
|
182 | 201 |
|
| 202 | + Parameters |
| 203 | + ---------- |
| 204 | + values |
| 205 | + If "cumulative" (the default), the result can be summed to get the |
| 206 | + chat's overall token usage (helpful for computing overall cost of |
| 207 | + the chat). If "discrete", the result can be summed to get the number of |
| 208 | + tokens the turns will cost to generate the next response (helpful |
| 209 | + for estimating cost of the next response, or for determining if you |
| 210 | + are about to exceed the token limit). |
| 211 | +
|
183 | 212 | Returns
|
184 | 213 | -------
|
185 |
| - list[tuple[int, int] | None] |
186 |
| - A list of tuples, where each tuple contains the start and end token |
187 |
| - indices for a turn. |
| 214 | + list[int] |
| 215 | + A list of token counts for each (non-system) turn in the chat. The |
| 216 | + 1st turn includes the tokens count for the system prompt (if any). |
| 217 | +
|
| 218 | + Raises |
| 219 | + ------ |
| 220 | + ValueError |
| 221 | + If the chat's turns (i.e., `.get_turns()`) are not in an expected |
| 222 | + format. This may happen if the chat history is manually set (i.e., |
| 223 | + `.set_turns()`). In this case, you can inspect the "raw" token |
| 224 | + values via the `.get_turns()` method (each turn has a `.tokens` |
| 225 | + attribute). |
188 | 226 | """
|
189 |
| - return [turn.tokens for turn in self._turns] |
| 227 | + |
| 228 | + turns = self.get_turns(include_system_prompt=False) |
| 229 | + |
| 230 | + if values == "cumulative": |
| 231 | + return [turn.tokens for turn in turns] |
| 232 | + |
| 233 | + if len(turns) == 0: |
| 234 | + return [] |
| 235 | + |
| 236 | + err_info = ( |
| 237 | + "This can happen if the chat history is manually set (i.e., `.set_turns()`). " |
| 238 | + "Consider getting the 'raw' token values via the `.get_turns()` method " |
| 239 | + "(each turn has a `.tokens` attribute)." |
| 240 | + ) |
| 241 | + |
| 242 | + # Sanity checks for the assumptions made to figure out user token counts |
| 243 | + if len(turns) == 1: |
| 244 | + raise ValueError( |
| 245 | + "Expected at least two turns in the chat history. " + err_info |
| 246 | + ) |
| 247 | + |
| 248 | + if len(turns) % 2 != 0: |
| 249 | + raise ValueError( |
| 250 | + "Expected an even number of turns in the chat history. " + err_info |
| 251 | + ) |
| 252 | + |
| 253 | + if turns[0].role != "user": |
| 254 | + raise ValueError( |
| 255 | + "Expected the 1st non-system turn to have role='user'. " + err_info |
| 256 | + ) |
| 257 | + |
| 258 | + if turns[1].role != "assistant": |
| 259 | + raise ValueError( |
| 260 | + "Expected the 2nd turn non-system to have role='assistant'. " + err_info |
| 261 | + ) |
| 262 | + |
| 263 | + if turns[1].tokens is None: |
| 264 | + raise ValueError( |
| 265 | + "Expected the 1st assistant turn to contain token counts. " + err_info |
| 266 | + ) |
| 267 | + |
| 268 | + res: list[int] = [ |
| 269 | + # Implied token count for the 1st user input |
| 270 | + turns[1].tokens[0], |
| 271 | + # The token count for the 1st assistant response |
| 272 | + turns[1].tokens[1], |
| 273 | + ] |
| 274 | + for i in range(1, len(turns) - 1, 2): |
| 275 | + ti = turns[i] |
| 276 | + tj = turns[i + 2] |
| 277 | + if ti.role != "assistant" or tj.role != "assistant": |
| 278 | + raise ValueError( |
| 279 | + "Expected even turns to have role='assistant'." + err_info |
| 280 | + ) |
| 281 | + if ti.tokens is None or tj.tokens is None: |
| 282 | + raise ValueError( |
| 283 | + "Expected role='assistant' turns to contain token counts." |
| 284 | + + err_info |
| 285 | + ) |
| 286 | + res.extend( |
| 287 | + [ |
| 288 | + # Implied token count for the user input |
| 289 | + tj.tokens[0] - sum(ti.tokens), |
| 290 | + # The token count for the assistant response |
| 291 | + tj.tokens[1], |
| 292 | + ] |
| 293 | + ) |
| 294 | + |
| 295 | + return res |
190 | 296 |
|
191 | 297 | def app(
|
192 | 298 | self,
|
|
0 commit comments