Fine-Tuning LLMs Using Your Data
Fine-tuning LLMs using your data refers to the process of adjusting the parameters of a pre-trained Large Language Model (LLM) to better fit a specific dataset for your applications. This can help improve the model’s performance on tasks specific to a particular domain or use case.
LLMs are machine learning models that are very effective at performing language-related tasks such as translation, answering questions, chat and content summarization, as well as content and code generation. They can be fine-tuned using a small amount of task-specific data, enabling them to perform better on that task with limited examples.
For example, imagine you run an e-commerce site selling camera products and you want to condense all reviews for a product into one summary for customers. You could fine-tune a pre-trained LLM such as T5 or Vicuna on your dataset of product reviews to specialize it for this task.
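As a rough sketch of what training data for this task might look like, the snippet below turns grouped reviews into input/target text pairs. The field names (`reviews`, `summary`), the sample record, and the helper `make_example` are hypothetical and only for illustration. The `summarize: ` prefix follows the task-prefix convention of T5-style models, and you would still need trusted summaries to use as targets.

```python
# Hypothetical sketch: build (input, target) text pairs for a
# review-summarization fine-tuning set. All data here is invented.

def make_example(reviews, summary, prefix="summarize: "):
    """Join one product's reviews into a single prefixed input string."""
    joined = " ".join(r.strip() for r in reviews if r.strip())
    return {"input": prefix + joined, "target": summary.strip()}

products = [
    {
        "reviews": ["Sharp lens, but the battery drains fast.",
                    "Great autofocus. Battery life is poor."],
        "summary": "Excellent optics and autofocus; weak battery life.",
    },
]

examples = [make_example(p["reviews"], p["summary"]) for p in products]
print(examples[0]["input"])
```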
Fine-Tuning vs. Prompt Engineering
Prompt engineering is the process of providing enough context, instructions, and examples to the model at inference time to get it to do what you want without changing the underlying weights.
Fine-tuning, on the other hand, involves directly updating the model parameters using a dataset that captures the distribution of tasks you want it to accomplish. This allows the model to specialize in particular use cases and improves its performance in specific domains.
In other words, prompt engineering is about manipulating the input prompt to coax the model into a region of its latent space such that the probability distribution of the next tokens it predicts matches your intent. Fine-tuning is about adjusting the model’s parameters to better fit a specific dataset.
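To make the prompt engineering side concrete, here is a minimal sketch of a few-shot prompt builder. The function name `build_few_shot_prompt`, the `Review:`/`Summary:` layout, and the sample reviews are my own assumptions for illustration; the point is only that the model's weights stay fixed while the input text carries the instruction and examples.

```python
# Sketch of prompt engineering: the model's weights are untouched; only
# the input text changes. All example data below is invented.

def build_few_shot_prompt(instruction, examples, query):
    """Assemble an instruction, worked examples, and the new query."""
    parts = [instruction.strip(), ""]
    for review, summary in examples:
        parts.append(f"Review: {review}")
        parts.append(f"Summary: {summary}")
        parts.append("")
    parts.append(f"Review: {query}")
    parts.append("Summary:")  # left open for the model to complete
    return "\n".join(parts)

prompt = build_few_shot_prompt(
    "Summarize each camera review in one sentence.",
    [("The lens is sharp but the battery dies quickly.",
      "Sharp lens, short battery life.")],
    "Autofocus hunts in low light, though colors look great.",
)
print(prompt)
```

With fine-tuning, examples like these would instead become training records, and the prompt sent at inference time could be much shorter.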
The Process of Fine-Tuning
For my experiment I used the FastChat project and documentation.
In practice, you will follow the fine-tuning directions for a specific pre-trained LLM, which are usually provided with public models from Hugging Face and other sources. We will look at the general steps taken in fine-tuning and then look at just one example.
Here are the steps typically involved in fine-tuning an LLM:
- Prepare your dataset: to fine-tune the LLM, you’ll need a dataset that aligns with your target domain or task. Ensure your dataset is large enough to cover the variations in your domain or task.
- Configure the training parameters: fine-tuning involves adjusting the LLM’s weights based on the custom dataset. This step involves configuring the training parameters such as learning rate, batch size, and number of epochs.
- Set up the training environment: set up the hardware and software environment for training, such as selecting the appropriate GPU and installing necessary libraries.
- Fine-tune the model on your custom dataset.
- Evaluate the fine-tuned model: evaluate the performance of the fine-tuned model on a separate validation dataset to ensure that it has improved on the target task.
- Save the fine-tuned model so you can later use it for inference on new data.
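The data preparation, configuration, and evaluation steps above can be sketched in plain Python. This is only an illustration under my own assumptions: the `config` values are placeholders rather than recommendations, and `train_validation_split` and `steps_per_epoch` are hypothetical helpers, not part of any library.

```python
import random

# Hypothetical hyperparameters; the values are placeholders, not
# recommendations.
config = {"learning_rate": 2e-5, "batch_size": 4, "num_epochs": 1}

def train_validation_split(records, validation_fraction=0.1, seed=42):
    """Shuffle a copy of the records and hold out a validation slice."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * validation_fraction))
    return shuffled[n_val:], shuffled[:n_val]

def steps_per_epoch(n_train, batch_size):
    """Number of batches needed to see every training record once."""
    return -(-n_train // batch_size)  # ceiling division

records = [{"id": i} for i in range(50)]  # stand-in for real examples
train, val = train_validation_split(records)
print(len(train), len(val), steps_per_epoch(len(train), config["batch_size"]))
```

Holding out the validation records before training is what makes the later evaluation step meaningful, since the model never sees them while its weights are updated.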
Building a Chat Application Using Text From My Books and Fine-Tuning
Preparing data for fine-tuning can be challenging, and fine-tuning existing LLMs can be expensive. Here we look at an experiment I performed using manuscript files for several of my books. I “chunked” the entirety of this text and used the OpenAI GPT-3.5 API to generate fine-tuning questions. My understanding of the OpenAI API terms and conditions is that you can only do this for non-commercial research.
I wrote the following Python script to chunk my manuscript files, call the OpenAI API, and write the prompts to one large JSON file:
from langchain.text_splitter import MarkdownTextSplitter
import json
import openai
import os
import time

openai.api_key = os.getenv("OPENAI_API_KEY")

def completion(s):
    # Uses the legacy (pre-1.0) openai package interface.
    return openai.ChatCompletion.create(model="gpt-3.5-turbo",
                                        messages=[{"role": "user",
                                                   "content": s}])

def gen_question(text):
    time.sleep(0.5)  # simple rate limiting between API calls
    try:
        q = completion(
            f"Generate a good question for the following text:\n{text}")
        return q.choices[0]['message']['content']
    except Exception as e:
        print("error:", e)
        return ""

directories = ["../../lovinglisp-book/manuscript/",
               "../../LangChain-book/manuscript/",
               "../../Java-AI-Book/manuscript",
               "../../haskell_book/manuscript/"
              ]

# Concatenate all Markdown manuscript files into one string.
result = ''

for directory in directories:
    for filename in os.listdir(directory):
        if filename.endswith('.md'):
            print("processing filename:", filename)
            with open(os.path.join(directory, filename), 'r') as f:
                result += f.read()
            print("length of result=", len(result))

splitter = MarkdownTextSplitter(chunk_size=200, chunk_overlap=10)
docs = splitter.create_documents([result])

chunks = [doc.page_content for doc in docs]

records = []
counter = 0

for answer in chunks:
    # Skip chunks that contain links, code, or markup, or that are too short.
    if ("http" in answer or "{" in answer or "[" in answer or
            "~" in answer or "**" in answer or len(answer) < 100):
        continue
    question = gen_question(answer).replace('"', '').replace('\n', '')
    print(question)
    if not question:
        continue  # the API call failed, so skip this chunk
    ans = answer.replace('"', '').replace('\n', '').replace('#', '')
    records.append({
        "id": f"identity_{counter}",
        "conversations": [
            {"from": "human", "value": question},
            {"from": "gpt", "value": ans}
        ]
    })
    counter += 1

# json.dump handles escaping and avoids a trailing-comma parsing error.
with open("watson_book_prompts_edited.json", "w") as out_file:
    json.dump(records, out_file, indent=1)
Here is a small part of the generated JSON file:
{
  "id": "identity_6",
  "conversations": [
    {
      "from": "human",
      "value": "What are the key features of CLML that make it suitable for machine learning, time series data analysis, and matrix and tensor operations?"
    },
    {
      "from": "gpt",
      "value": "Even though the learning curve is a bit steep, CLML provides a lot of functionality for machine learning, dealing with time series data, and general matrix and tensor operations."
    }
  ]
},
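Before spending GPU time, it can be worth checking that every generated record has the shape shown above. The following is a minimal sketch, assuming single-turn records with one "human" turn followed by one "gpt" turn, as my script produces. The function `validate_records` and the two sample records are hypothetical and not part of FastChat.

```python
import json

def validate_records(records):
    """Return a list of problems found in single-turn conversation records."""
    problems = []
    seen_ids = set()
    for i, rec in enumerate(records):
        rec_id = rec.get("id")
        if not rec_id or rec_id in seen_ids:
            problems.append(f"record {i}: missing or duplicate id")
        seen_ids.add(rec_id)
        turns = rec.get("conversations", [])
        roles = [t.get("from") for t in turns]
        if roles != ["human", "gpt"]:
            problems.append(f"record {i}: expected human then gpt, got {roles}")
        if any(not t.get("value", "").strip() for t in turns):
            problems.append(f"record {i}: empty value")
    return problems

# Invented sample data; the second record has an empty question.
sample = json.loads("""
[
  {"id": "identity_0",
   "conversations": [{"from": "human", "value": "What is CLML?"},
                     {"from": "gpt", "value": "A Common Lisp machine learning library."}]},
  {"id": "identity_1",
   "conversations": [{"from": "human", "value": ""},
                     {"from": "gpt", "value": "An answer with no question."}]}
]
""")
print(validate_records(sample))
```

To check the real file, you would load it with `json.load` and pass the resulting list to the same function.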
I ran this experiment on a GPU VPS from Lambda Labs. I slightly modified the local environment:
pip install fschat transformers openai
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:1024'
I used the following command (derived from the FastChat documentation examples):
python -m torch.distributed.launch fastchat/train/train_flant5.py \
    --model_name_or_path google/flan-t5-small \
    --data_path watson_book_prompts_edited.json \
    --bf16 True \
    --output_dir ./checkpoints_flant5_3b \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 300 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap T5Block \
    --tf32 True \
    --model_max_length 2048 \
    --preprocessed_path ./preprocessed_data/processed.json \
    --gradient_checkpointing True
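Several of these flags interact, so here is a small sketch of how the batch size, gradient accumulation, warmup ratio, and cosine schedule settings might combine. It shows the general shape of linear warmup followed by cosine decay and is not necessarily the exact computation the training library performs. The dataset size, the `num_gpus` parameter, and both helper functions are assumptions for illustration.

```python
import math

def effective_batch_size(per_device, grad_accum, num_gpus=1):
    """Examples contributing to each optimizer step."""
    return per_device * grad_accum * num_gpus

def lr_at_step(step, total_steps, base_lr=2e-5, warmup_ratio=0.03):
    """Linear warmup followed by cosine decay to zero (a common shape)."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

num_examples = 4000  # hypothetical dataset size
batch = effective_batch_size(per_device=1, grad_accum=4)
total_steps = math.ceil(num_examples / batch)
print(batch, total_steps)
for step in (0, 30, total_steps // 2, total_steps):
    print(step, lr_at_step(step, total_steps))
```

With a per-device batch size of 1, gradient accumulation is what keeps each weight update from being driven by a single example.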
My experiment fine-tuning an existing model using prompts produced from my books is a work in progress. I will update this chapter as I get better results.