How to Fine-Tune Llama-2 (7B) to Beat ChatGPT
Large language models (LLMs) are becoming increasingly powerful, and they are being used for a variety of tasks, including chatbots, text generation, and question answering. However, LLMs can be expensive and resource-intensive to train. In this blog post, we will show you how to fine-tune a smaller LLM (7B parameters) to outperform ChatGPT on a specific task.
Fine-tuning is the process of further training a pretrained LLM on a specific dataset to improve its performance on a particular task. In this case, we will fine-tune a 7B LLM to multiply two numbers.
Let us start with the performance of different LLMs on a simple multiplication task: 458 * 987 = 452,046.
First, we will look at Meta's recently released Llama-2-7B. We deployed Llama-2 on TrueFoundry and tried it out with TrueFoundry's LangChain integration. Here is the result:
Clearly, it doesn't perform well at the task (which is expected given the size of the model). Let us see how state-of-the-art models perform on the same task:
Here is the result from ChatGPT (GPT-3.5 Turbo):
The answer (451,086) is very close to the actual answer, 452,046, but it is not correct.
Let us try a different prompt, asking it to do a step-by-step calculation, and see what it does:
But again, it arrives at an incorrect result: 450,606 🤨
Finally, let us try the state of the art, GPT-4, and check how it performs on the task:
Its answer is very close to the correct answer (452,046) and might look right at a glance. But on closer inspection, it is incorrect.
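To put "very close" in numbers, here is a quick Python check of the two GPT-3.5 Turbo answers quoted above against the exact product. GPT-4's exact answer is not quoted in the text, so it is left out; the dictionary labels are just illustrative names.

```python
# Ground truth for the test question.
exact = 458 * 987

# Answers quoted above for GPT-3.5 Turbo (plain prompt, then step-by-step prompt).
reported = {
    "GPT-3.5 Turbo": 451_086,
    "GPT-3.5 Turbo, step by step": 450_606,
}

errors = {}
for name, answer in reported.items():
    abs_err = abs(exact - answer)
    errors[name] = abs_err
    # Absolute error and error relative to the true product.
    print(f"{name}: {answer:,} vs {exact:,} -> off by {abs_err:,} ({abs_err / exact:.2%})")
```

Both answers miss by less than 0.5% of the true value, which is exactly why they look plausible without checking.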
Why are LLMs performing poorly on simple multiplication tasks?
The answer is simple: they are not "trained" for it.
LLMs are trained on massive datasets of text and code, but this data is not always structured in a way that is helpful for mathematical reasoning. For example, the data may not include explicit representations of mathematical concepts, such as addition, subtraction, multiplication, and division. This can make it difficult for LLMs to learn how to perform these operations correctly.
So we wondered: is it possible to train an LLM to perform mathematical operations? We ran an experiment, and the results were quite promising!
Fine-tuning Llama-2 7B to learn multiplication
Curating a Dataset
To start, we curated a dataset of 40,000 prompt-response pairs. We tried to mimic the exact steps of the long multiplication method, which involves breaking the multiplicand and multiplier into their individual digits, multiplying each digit of the multiplicand by each digit of the multiplier, and shifting each result to the appropriate position based on the place values of the two digits. Finally, all the individual products are added together to get the final result.
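The long multiplication method can be written compactly. Writing an m-digit number a and an n-digit number b in terms of their digits a_i and b_j (indexed from the ones place, so i = 0 is the ones digit), every partial product is one digit pair times a power of ten, and the shift applied to it is i + j:

```latex
a = \sum_{i=0}^{m-1} a_i \, 10^{i}, \qquad
b = \sum_{j=0}^{n-1} b_j \, 10^{j}, \qquad
a \times b = \sum_{i=0}^{m-1} \sum_{j=0}^{n-1} a_i \, b_j \, 10^{\,i+j}
```

For example, with a = 5244 and b = 1932, the tens digit a_1 = 4 times the hundreds digit b_2 = 9 gives 36 × 10^3 = 36000, one of the 4 × 4 = 16 partial products in the sample response.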
Here is a sample prompt:
You are a helpful assistant providing detailed answers to multiplication questions. Ensure you provide a thorough explanation using the long multiplication method for the math problem given below.
If you don't know the answer to a question, please don't share false information. What is 5244 * 1932?
And here is a sample response:
Here are the steps for multiplying 5244 * 1932 using the long multiplication method.
Step 1: Break 5244 into its individual digits: 5 (thousands place), 2 (hundreds place), 4 (tens place), and 4 (ones place).
Step 2: Similarly, break 1932 into its individual digits: 1 (thousands place), 9 (hundreds place), 3 (tens place), and 2 (ones place).
- 4 (from 1 place of first number) * 2 (from 1 place of second number) = 8 -> Shift it 0 places to the left (because we are multiplying with the 1 place of the second number) to get 8
- 4 (from 1 place of first number) * 3 (from 10 place of second number) = 12 -> Shift it 1 places to the left (because we are multiplying with the 10 place of the second number) to get 120
- 4 (from 1 place of first number) * 9 (from 100 place of second number) = 36 -> Shift it 2 places to the left (because we are multiplying with the 100 place of the second number) to get 3600
- 4 (from 1 place of first number) * 1 (from 1000 place of second number) = 4 -> Shift it 3 places to the left (because we are multiplying with the 1000 place of the second number) to get 4000
- 4 (from 10 place of first number) * 2 (from 1 place of second number) = 8 -> Shift it 1 places to the left (because we are multiplying with the 1 place of the second number) to get 80
- 4 (from 10 place of first number) * 3 (from 10 place of second number) = 12 -> Shift it 2 places to the left (because we are multiplying with the 10 place of the second number) to get 1200
- 4 (from 10 place of first number) * 9 (from 100 place of second number) = 36 -> Shift it 3 places to the left (because we are multiplying with the 100 place of the second number) to get 36000
- 4 (from 10 place of first number) * 1 (from 1000 place of second number) = 4 -> Shift it 4 places to the left (because we are multiplying with the 1000 place of the second number) to get 40000
- 2 (from 100 place of first number) * 2 (from 1 place of second number) = 4 -> Shift it 2 places to the left (because we are multiplying with the 1 place of the second number) to get 400
- 2 (from 100 place of first number) * 3 (from 10 place of second number) = 6 -> Shift it 3 places to the left (because we are multiplying with the 10 place of the second number) to get 6000
- 2 (from 100 place of first number) * 9 (from 100 place of second number) = 18 -> Shift it 4 places to the left (because we are multiplying with the 100 place of the second number) to get 180000
- 2 (from 100 place of first number) * 1 (from 1000 place of second number) = 2 -> Shift it 5 places to the left (because we are multiplying with the 1000 place of the second number) to get 200000
- 5 (from 1000 place of first number) * 2 (from 1 place of second number) = 10 -> Shift it 3 places to the left (because we are multiplying with the 1 place of the second number) to get 10000
- 5 (from 1000 place of first number) * 3 (from 10 place of second number) = 15 -> Shift it 4 places to the left (because we are multiplying with the 10 place of the second number) to get 150000
- 5 (from 1000 place of first number) * 9 (from 100 place of second number) = 45 -> Shift it 5 places to the left (because we are multiplying with the 100 place of the second number) to get 4500000
- 5 (from 1000 place of first number) * 1 (from 1000 place of second number) = 5 -> Shift it 6 places to the left (because we are multiplying with the 1000 place of the second number) to get 5000000
Now, add all these results together:
The sum of 8 + 120 + 3600 + 4000 + 80 + 1200 + 36000 + 40000 + 400 + 6000 + 180000 + 200000 + 10000 + 150000 + 4500000 + 5000000 = 10131408.
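The sum in the sample can be verified mechanically. This short Python check uses the sixteen partial products listed above and confirms both the total and that each group of four equals one digit of 5244 (times its place value) multiplied by 1932:

```python
# Partial products from the sample response, grouped by the digit of 5244 being used
# (ones, tens, hundreds, thousands), each multiplied by the digits 2, 3, 9, 1 of 1932.
partial_products = [
    8, 120, 3600, 4000,               # ones digit 4
    80, 1200, 36000, 40000,           # tens digit 4
    400, 6000, 180000, 200000,        # hundreds digit 2
    10000, 150000, 4500000, 5000000,  # thousands digit 5
]

total = sum(partial_products)
print(total)  # 10131408
assert total == 5244 * 1932  # the long-multiplication sum matches the direct product

# Each group of four equals one digit's contribution: (digit * place value) * 1932.
groups = [partial_products[k:k + 4] for k in range(0, 16, 4)]
for digit_value, group in zip([4, 40, 200, 5000], groups):
    assert sum(group) == digit_value * 1932
```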
Then we wrote a Python script that generates prompt-response pairs in this format by randomly selecting two numbers to multiply, and used it to generate the 40,000-row dataset.
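The authors' script itself is not shown. As a hypothetical sketch, the Python below reproduces the format of the sample prompt and response for random operands and writes the rows as JSON Lines. The 1000 to 9999 operand range (matching the 4-digit sample), the `prompt`/`response` field names, the JSONL output, and all function names are assumptions for illustration.

```python
import json
import random

# Place-value names for up to four-digit operands (the sample uses 4-digit numbers).
PLACE_NAMES = ["ones", "tens", "hundreds", "thousands"]

# Prompt text copied from the sample prompt above.
INSTRUCTION = (
    "You are a helpful assistant providing detailed answers to multiplication questions. "
    "Ensure you provide a thorough explanation using the long multiplication method "
    "for the math problem given below.\n"
    "If you don't know the answer to a question, please don't share false information."
)


def build_prompt(a: int, b: int) -> str:
    """Fill the sample prompt template with two operands."""
    return f"{INSTRUCTION} What is {a} * {b}?"


def describe_digits(n: int) -> str:
    """5244 -> '5 (thousands place), 2 (hundreds place), 4 (tens place), and 4 (ones place)'."""
    digits = str(n)
    parts = [
        f"{d} ({PLACE_NAMES[len(digits) - 1 - pos]} place)"
        for pos, d in enumerate(digits)
    ]
    if len(parts) == 1:
        return parts[0]
    return ", ".join(parts[:-1]) + ", and " + parts[-1]


def build_response(a: int, b: int) -> str:
    """Write out every digit-by-digit partial product, mirroring the sample response."""
    lines = [
        f"Here are the steps for multiplying {a} * {b} using the long multiplication method.",
        f"Step 1: Break {a} into its individual digits: {describe_digits(a)}.",
        f"Step 2: Similarly, break {b} into its individual digits: {describe_digits(b)}.",
    ]
    shifted_products = []
    # Walk both numbers from the ones place upward: i indexes digits of a, j of b.
    for i, da in enumerate(int(c) for c in reversed(str(a))):
        for j, db in enumerate(int(c) for c in reversed(str(b))):
            product = da * db
            shifted = product * 10 ** (i + j)
            shifted_products.append(shifted)
            lines.append(
                f"- {da} (from {10 ** i} place of first number) * {db} "
                f"(from {10 ** j} place of second number) = {product} -> "
                f"Shift it {i + j} places to the left (because we are multiplying "
                f"with the {10 ** j} place of the second number) to get {shifted}"
            )
    total = sum(shifted_products)
    assert total == a * b  # sanity check: the partial products must add up
    lines.append("Now, add all these results together:")
    lines.append(f"The sum of {' + '.join(map(str, shifted_products))} = {total}.")
    return "\n".join(lines)


def iter_rows(n_rows: int, seed: int = 0, low: int = 1000, high: int = 9999):
    """Yield prompt/response dicts for randomly chosen operand pairs."""
    rng = random.Random(seed)
    for _ in range(n_rows):
        a, b = rng.randint(low, high), rng.randint(low, high)
        yield {"prompt": build_prompt(a, b), "response": build_response(a, b)}


def write_jsonl(path: str, n_rows: int = 40_000, seed: int = 0) -> None:
    """Write the dataset as one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for row in iter_rows(n_rows, seed=seed):
            f.write(json.dumps(row) + "\n")
```

For example, `write_jsonl("multiplication.jsonl")` would write 40,000 rows. Deriving every response from the digits, rather than writing it by hand, guarantees that each step and the final sum are arithmetically correct.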
Fine-tuning the Model
Once the dataset is ready, we need to fine-tune the model.
We use Meta's fine-tuned chat variant of Llama-2 (7 billion parameters) as the base model.
We performed QLoRA fine-tuning using the bitsandbytes and PEFT libraries. Here is the LoRA config we used:
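from peft import LoraConfig

# LoRA adapter settings used for the QLoRA fine-tuning run
lora_config = LoraConfig(
    lora_alpha=16,          # scaling factor; updates are scaled by lora_alpha / r
    lora_dropout=0.1,       # dropout probability applied in the LoRA layers
    r=64,                   # rank of the low-rank update matrices
    bias="none",            # do not train any bias parameters
    task_type="CAUSAL_LM",
    target_modules=[        # attention projections in each Llama decoder layer
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
    ],
)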
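The post names QLoRA with bitsandbytes and PEFT but does not show the quantization step. Below is a minimal sketch of how the base model could be loaded in 4-bit and wrapped with the LoRA settings above. It assumes the Hugging Face `transformers`, `peft`, `bitsandbytes`, and `accelerate` packages, a CUDA GPU, and access to the gated `meta-llama/Llama-2-7b-chat-hf` repository. The repo id, the NF4 / double-quantization / float16 settings, and the function name are assumptions, not the authors' stated values.

```python
def load_qlora_model(model_id: str = "meta-llama/Llama-2-7b-chat-hf"):
    """Load a 4-bit quantized base model and attach LoRA adapters (QLoRA-style sketch)."""
    # Imported inside the function so the sketch can be defined without the GPU stack.
    import torch
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    # Assumed quantization settings: 4-bit NF4 weights with double quantization,
    # computing in float16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=bnb_config, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token  # the Llama-2 tokenizer has no pad token

    # Prepare the quantized model for training (enables gradient checkpointing by default).
    model = prepare_model_for_kbit_training(model)

    # Same LoRA settings as the config shown in the post.
    lora_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()  # only the adapter weights are trainable
    return model, tokenizer
```

Calling `model, tokenizer = load_qlora_model()` returns a PEFT model whose frozen base weights stay in 4-bit while the LoRA adapters train; it can then be trained on the prompt-response dataset with a standard Hugging Face training loop.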
Training on the 40,000-row dataset took around 8 hours on a machine with an A100 40GB GPU.
Finally, we deployed the fine-tuned model on TrueFoundry again. Here are the results:
At last!! We can see that the fine-tuned model calculates the result correctly.
Conclusion
Although arithmetic is not a task we would normally use an LLM for, this example demonstrates how a "small" LLM (7B parameters), properly fine-tuned for a specific task, can outperform "large" LLMs (like GPT-3.5 Turbo, often cited at 175B parameters, and GPT-4) on that task.
Smaller fine-tuned models are cheap at inference, can beat larger models on the specialized tasks they are trained for, and can be deployed easily on your own cloud!
We have also written a detailed blog post on fine-tuning Llama 2.