on-off-line

对于在线处理和离线处理的讨论

问题描述:

​ 我在使用 MLLM 的 VQA 方式进行单目标跟踪的时候,出现了一个问题是,由于输入的图片分辨率不一样,如果我使用了 normalize bbox 模型并不能正确理解 normalize bbox 的含义,因此我需要正确裁剪 bbox 输入模型相同分辨率的图像,因此需要正确处理数据集,有两种方式,将数据集先全部裁剪完再开始训练,这就是 off-line 处理;另一种就是边处理数据边裁剪,每遍历到数据集中的一个 sample 就裁剪一个 sample 内的图片。

结论 TLDR 版:

  • online 处理一般而言省磁盘容量,但是会有更大的计算开销,offline 处理一般省计算开销(提前处理的计算量不计入训练过程),但是对磁盘容量的消耗更大

  • 对于科研这种需要快速迭代的工作来说,第一考虑的就是时间成本,一般来说 online 和 offline 的时间差别不大,因此写代码时间消耗反而是最大的(这个看情况),那么就是代码怎么好改就选取哪个方法

  • online 与 offline 不止限制于数据处理上,还在各个方面会出现,例如推荐系统的 online training 和 offline training 的区别;模型推理的 online inference 和 offline inference 的区别等等

online 实现难度分析:

​ 因为 MLLM 模型微调全部是在 llama factory 框架下进行微调的,但是 llama factory 并没有对数据集有很多拓展功能,下面就需要分析源代码看一下数据集的处理逻辑,评估一下在这个框架下是否真的好改代码,我们需要明确目标:对图片数据进行裁剪的最佳位置在调用加载图片进入内存之后,在对图片进行归一化等数值化处理之前。下面我们一步步查看源码:

llama factory 数据集处理流程链:

我的数据集一个 sample 示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
{
"messages": [
{
"role": "user",
"content": "The first 2 images (<image><image>) show the object of interest: 'the player who controls the ball'. Please locate this object in the final image (<image>). Provide its bounding box as [x1, y1, x2, y2] coordinates within that image."
},
{
"role": "assistant",
"content": "The object 'the player who controls the ball' is located at [38, 248, 103, 313]."
}
],
"images": [
"cropped_dataset_visible/cropped_images/sample_000000/template_000.jpg",
"cropped_dataset_visible/cropped_images/sample_000000/template_001.jpg",
"cropped_dataset_visible/cropped_images/sample_000000/search.jpg"
]
},

加载数据进内存:

  • 训练流程开始时,会解析命令行参数和配置文件,包括数据集名称和路径 (data/dataset_info.json 用于定义数据集元信息,如列名、格式等)

  • get_dataset 函数负责根据数据集名称和配置加载原始数据集。它会调用 _load_single_dataset (src/llamafactory/data/loader.py) 从 Hugging Face Hub、ModelScope、本地文件等来源加载数据。

load_dataset

自定义选择数据处理器:

  • _get_dataset_processor 函数根据当前的训练阶段 (stage,如 “sft”, “pt”, “rm”) 选择一个合适的 DatasetProcessor 子类

image-20250425202855356

其中各种 dataprocessor 保存在 src/llamafactory/data/processor 路径下:

processors

​ 因此每次 iter 会调用不同的 processor 对数据进行预处理,因此我们需要进入 processor 内容进一步查看代码,我们使用的是 SFT 训练,因此就用 SupervisedDatasetProcessor 为例,它的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: list[dict[str, str]],
response: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
) -> tuple[list[int], list[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor
)
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
if self.data_args.mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns

for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= self.data_args.cutoff_len:
break

source_len, target_len = infer_seqlen(
len(source_ids), len(target_ids), self.data_args.cutoff_len - total_length
)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len

if self.data_args.train_on_prompt:
source_label = source_ids
elif self.template.efficient_eos:
source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
source_label = [IGNORE_INDEX] * source_len

if self.data_args.mask_history and turn_idx != 0: # train on the last turn only
target_label = [IGNORE_INDEX] * target_len
else:
target_label = target_ids

if self.data_args.mask_history: # reversed sequences
input_ids = source_ids + target_ids + input_ids
labels = source_label + target_label + labels
else:
input_ids += source_ids + target_ids
labels += source_label + target_label

if self.template.efficient_eos:
input_ids += [self.tokenizer.eos_token_id]
labels += [self.tokenizer.eos_token_id]

return input_ids, labels

def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue

input_ids, labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])

return model_inputs

def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")

  • _encode_data_example 这个私有方法是整个处理流程的核心,它负责处理单个数据样本(一个完整的对话或多轮对话),将其转换为模型训练所需的 input_idslabels 序列
  • preprocess_dataset 这是公开的预处理方法,接收一个包含多个数据样本的批次,遍历这些样本,并调用 _encode_data_example 来处理每一个样本。它将处理结果组织成一个字典,准备用于 PyTorch 的 DataLoader

一步步调试查看中间过程可以发现,