Getting machines to perform reasoning tasks has long been a cherished goal of AI. These tasks include math word problems and the kind of analytical commonsense reasoning you typically see in standardized tests such as the SAT and GRE. Today’s large language models (LLMs) can perform many simple reasoning tasks out of the box, and there is a growing field of research into how we might improve LLM reasoning on more complex tasks.
One of the more involved ways of improving LLM reasoning is reinforcement learning, motivated by the large performance improvements we have seen from Reinforcement Learning from Human Feedback (RLHF). These enhancements require considerable effort: collecting large amounts of training data and choosing the right reinforcement learning algorithms. There is a good discussion of this topic in a paper by Alex Havrilla and team: “Teaching Large Language Models to Reason with Reinforcement Learning“.
I’ll dig into the details of Havrilla’s paper at a later time, but today I’d like to focus on an easier technique that many people can try out directly. Jason Wei and a team at Google describe a method that uses prompt engineering to get LLMs to perform better reasoning: “Chain-of-Thought Prompting Elicits Reasoning in Large Language Models“.
The intuition behind the paper is straightforward. We humans think step by step when solving complex reasoning tasks. Maybe models will also get better at reasoning if they are shown this “step-by-step thinking” process. The team behind the paper experimented with giving LLMs worked examples of this process, using few-shot prompting to get the LLMs to perform better.
A simple example of an exemplar used in the few-shot prompt is:
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 – 15 = 6. The answer is 6.
In the example above, the intermediate reasoning steps in the answer (everything before “The answer is 6”) are the chain of thought that shows the model how to arrive at the correct answer.
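To make this concrete, here is a minimal sketch of how you might assemble such a few-shot chain-of-thought prompt and parse the model’s final answer. The helper names (`build_cot_prompt`, `extract_answer`) and the single exemplar are my own illustration, not code from the paper; the completed prompt would be sent to whatever LLM API you use.

```python
import re

# One hand-written chain-of-thought exemplar (the tree-planting example above).
# In practice, the paper's prompts include several such Q/A pairs.
COT_EXEMPLAR = (
    "Q: There are 15 trees in the grove. Grove workers will plant trees "
    "in the grove today. After they are done, there will be 21 trees. "
    "How many trees did the grove workers plant today?\n"
    "A: There are 15 trees originally. Then there were 21 trees after "
    "some more were planted. So there must have been 21 - 15 = 6. "
    "The answer is 6.\n"
)

def build_cot_prompt(question: str) -> str:
    """Prepend the worked exemplar so the model imitates step-by-step reasoning."""
    return f"{COT_EXEMPLAR}\nQ: {question}\nA:"

def extract_answer(completion: str):
    """Pull the final number after the 'The answer is' marker, if present."""
    match = re.search(r"The answer is\s+(-?\d+(?:\.\d+)?)", completion)
    return match.group(1) if match else None
```

Because the exemplar ends with a fixed phrase (“The answer is …”), the model’s completion tends to follow the same pattern, which makes the final answer easy to extract with a simple regular expression.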
This type of prompting can be useful in many complex natural language reasoning tasks. If you have an application where such reasoning could help, this simple change to your prompting strategy may improve its performance.