Toolformer: Language Models Can Teach Themselves to Use Tools
- Thesis: Train a model that learns in a self-supervised setting which API to call, when to call the API, how to pass the parameters, and how to incorporate the API results within the body of the text.
- Methods:
- Dataset(s):
- To avoid changing the vocabulary of the LM, existing tokens stand in for the special ones: `[` represents `<API>`, `]` represents `</API>`, and `->` separates the call from its result (all tokens after `->` are the result of the API call)
- Use few-shot examples for each tool and ask the LM to generate examples
- Sample positions \(i \in \{1, \dots, n\}\) by selecting the top-k positions whose probability of starting an API call exceeds a threshold, \(p_i > \text{threshold}\)
- For each position, generate \(m\) API calls
- Execute each API call and embed the results in the text
- Filter API calls using a weighted cross entropy over the tokens that follow the call. Intuitively, the text with the API call and its result should have lower cross entropy than the text with no API call, or with the call's input but not its result. So we compute the weighted cross entropy for 1) the empty string, 2) the API call with its input but without its result, and 3) the API call with both input and result
- An API call is kept only if its result helps predict the future tokens: min(weighted cross entropy of 1, weighted cross entropy of 2) minus the weighted cross entropy of 3 must be \(\geq\) the filtering threshold (1.0 in the paper)
- For every position that survives sampling and filtering, augment the text sequence with the API call and its response at that position. For example, if a sequence has 5 surviving positions, the API call and its response are embedded at each of them
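The sampling and filtering steps above can be sketched as follows. This is a minimal illustration, not the paper's code: the function names, the linearly decaying weights, and the input format (per-token log-probabilities already computed by the LM for the tokens after position \(i\)) are all assumptions.

```python
from typing import List, Sequence


def sample_positions(api_start_probs: Sequence[float], k: int, threshold: float) -> List[int]:
    """Keep positions whose API-start probability exceeds the threshold, then take the top k."""
    candidates = [i for i, p in enumerate(api_start_probs) if p > threshold]
    candidates.sort(key=lambda i: api_start_probs[i], reverse=True)
    return sorted(candidates[:k])


def decaying_weights(n: int, decay: float = 0.2) -> List[float]:
    """Assumed weighting: linear decay clipped at zero, normalized to sum to 1."""
    if n == 0:
        return []
    raw = [max(0.0, 1.0 - decay * t) for t in range(n)]
    total = sum(raw)
    return [r / total for r in raw]


def weighted_cross_entropy(token_logprobs: Sequence[float]) -> float:
    """Weighted negative log-likelihood of the tokens that follow the candidate position."""
    weights = decaying_weights(len(token_logprobs))
    return -sum(w * lp for w, lp in zip(weights, token_logprobs))


def keep_api_call(
    lp_no_call: Sequence[float],
    lp_call_only: Sequence[float],
    lp_call_with_result: Sequence[float],
    threshold: float = 1.0,
) -> bool:
    """Keep the call if its result lowers the loss by at least `threshold`."""
    loss_with_result = weighted_cross_entropy(lp_call_with_result)
    loss_baseline = min(
        weighted_cross_entropy(lp_no_call),
        weighted_cross_entropy(lp_call_only),
    )
    return loss_baseline - loss_with_result >= threshold
```

Each of the three log-probability lists scores the same continuation tokens; only the prefix differs (no call, call without result, call with result).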
- Model(s):
- GPT-J is the LM used for fine-tuning
- Training:
- Fine-tune on the augmented dataset with a standard language modeling objective
- Greedy decoding is used, with one change: if the `<API>` token is among the top 10 most likely tokens at any position, an API call is made at that position
- Only one API call is allowed per input sequence, to avoid getting stuck in a loop of API calls
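The modified greedy rule can be sketched as below. This is an illustration under assumptions: `next_token` and its arguments are hypothetical names, the `logprobs` mapping stands in for the LM's next-token distribution, and skipping the API token once a call has been used is one possible way to enforce the one-call limit, not necessarily the paper's.

```python
from typing import Dict


def next_token(
    logprobs: Dict[str, float],
    api_token: str = "[",
    k: int = 10,
    allow_api: bool = True,
) -> str:
    """Greedy choice, except the API token wins whenever it is among the k most likely tokens."""
    ranked = sorted(logprobs, key=logprobs.get, reverse=True)
    if allow_api and api_token in ranked[:k]:
        return api_token
    # Otherwise fall back to the most likely token that is not the API token.
    for token in ranked:
        if token != api_token:
            return token
    raise ValueError("no non-API token available")
```

Setting `allow_api=False` after the first call keeps the model from opening a second one in the same sequence.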
- Inference:
- Regular decoding is used until the model generates the `->` token; at that point the API call is executed, its response is inserted followed by the `</API>` token, and generation continues
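A rough sketch of this inference loop, using the `[`, `]`, and `->` stand-in tokens from these notes. Everything here is hypothetical scaffolding: `step` stands in for the LM (it returns the next token as a string, or `None` when done), `tools` maps tool names to plain Python callables, and the regular expression assumes calls are written as `[Name(input) ->`.

```python
import re
from typing import Callable, Dict, Optional

# Matches an open call at the end of the text, e.g. "[Calculator(2+3) ->".
CALL_PATTERN = re.compile(r"\[(\w+)\(([^\[\]]*)\)\s*->\s*$")


def decode_with_tools(
    step: Callable[[str], Optional[str]],
    tools: Dict[str, Callable[[str], str]],
    prompt: str,
    max_steps: int = 50,
) -> str:
    """Decode token by token; when "->" appears, run the tool and insert its result."""
    text = prompt
    for _ in range(max_steps):
        token = step(text)
        if token is None:
            break
        text += token
        if token.strip() == "->":
            match = CALL_PATTERN.search(text)
            if match and match.group(1) in tools:
                result = tools[match.group(1)](match.group(2))
                # Insert the response and close the call, then keep decoding.
                text += f" {result}]"
    return text
```

With a scripted `step` that emits `The answer is [Calculator(2+3) ->` token by token, the loop appends the tool's result and the closing `]` before decoding resumes.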
- Contribution: The Toolformer model, which still generalizes to natural language tasks and learns when and how to call APIs.
- Limitation:
- The fine-tuned LM cannot chain API calls (use the output of one call as the input to another)
- The LM is sensitive to the wording of its inputs
- Sample inefficient
- The model cannot work with a tool interactively
- Notes:
- Use the LM to generate synthetic annotated data for API calls from a few examples per API
- Toolformer must not lose its generality in performing other language-related tasks
- Tools explored:
- QA system: Uses an LM that was fine-tuned on the Natural Questions dataset
- Wikipedia search: Uses BM25 retrieval over a Wikipedia dump
- Calendar: Only returns the current date
- Calculator: Supports only the four basic arithmetic operations (+, -, *, /)
- Machine translation: Translates from any language into English. The source language is detected with a fastText classifier
- Bigger models make better use of the APIs and benefit from them
- Smaller models, below 775M parameters, see no benefit from using APIs
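To make the calculator tool from the list above concrete, here is a minimal sketch of a four-operation evaluator. It is not the paper's implementation: the `calculator` name, the use of Python's `ast` module for safe parsing, and rounding the result to two decimals are assumptions of this example.

```python
import ast
import operator

# Only the four basic operations are allowed.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def calculator(expression: str) -> str:
    """Evaluate +, -, *, / over numbers and return the result as text."""

    def evaluate(node):
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        if (
            isinstance(node, ast.Constant)
            and isinstance(node.value, (int, float))
            and not isinstance(node.value, bool)
        ):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](evaluate(node.left), evaluate(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -evaluate(node.operand)
        raise ValueError("unsupported expression")

    value = evaluate(ast.parse(expression, mode="eval"))
    return f"{value:.2f}"
```

Anything outside the four operations (for example `2**3` or a function call) raises `ValueError` instead of being evaluated, which keeps the tool safe to run on model-generated input.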
#nlp #llm #agents