from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

# Project-local imports; the exact module paths follow LLaMA-Factory's package
# layout and may differ in other checkouts.
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen

if TYPE_CHECKING:
    from ..mm_plugin import AudioInput, ImageInput, VideoInput

logger = logging.get_logger(__name__)


@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
    def _encode_data_example(
        self,
        prompt: list[dict[str, str]],
        response: list[dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
    ) -> tuple[list[int], list[int]]:
        # Render the conversation through the multimodal plugin, then encode it
        # turn by turn with the chat template.
        messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
        input_ids, labels = self.template.mm_plugin.process_token_ids(
            [], [], images, videos, audios, self.tokenizer, self.processor
        )
        encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
        total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
        if self.data_args.mask_history:
            encoded_pairs = encoded_pairs[::-1]  # process the last turn first so truncation keeps it
        for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
            if total_length >= self.data_args.cutoff_len:
                break
            # Split the remaining token budget between prompt and response,
            # then truncate both sides to fit within cutoff_len.
            source_len, target_len = infer_seqlen(
                len(source_ids), len(target_ids), self.data_args.cutoff_len - total_length
            )
            source_ids = source_ids[:source_len]
            target_ids = target_ids[:target_len]
            total_length += source_len + target_len
            # Mask prompt tokens unless training on the prompt; with
            # efficient_eos, keep one EOS label at the turn boundary.
            if self.data_args.train_on_prompt:
                source_label = source_ids
            elif self.template.efficient_eos:
                source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
            else:
                source_label = [IGNORE_INDEX] * source_len
            if self.data_args.mask_history and turn_idx != 0:  # supervise only the last turn
                target_label = [IGNORE_INDEX] * target_len
            else:
                target_label = target_ids
            if self.data_args.mask_history:  # turns were reversed, so prepend to restore order
                input_ids = source_ids + target_ids + input_ids
                labels = source_label + target_label + labels
            else:
                input_ids += source_ids + target_ids
                labels += source_label + target_label
        if self.template.efficient_eos:
            input_ids += [self.tokenizer.eos_token_id]
            labels += [self.tokenizer.eos_token_id]
        return input_ids, labels

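    # Label layout sketch (illustrative, not part of the original module): for
    # a single-turn example with train_on_prompt disabled and efficient_eos
    # off, the encoding yields
    #     input_ids: [p0, p1, p2, r0, r1]
    #     labels:    [IGNORE, IGNORE, IGNORE, r0, r1]
    # where IGNORE is IGNORE_INDEX (conventionally -100), so the loss is
    # computed only on response tokens.
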
    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
        # A valid example has an odd-length prompt history (user/assistant
        # alternation ending with a user turn) and exactly one response.
        model_inputs = defaultdict(list)
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue
            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
                system=examples["_system"][i],
                tools=examples["_tools"][i],
                images=examples["_images"][i] or [],
                videos=examples["_videos"][i] or [],
                audios=examples["_audios"][i] or [],
            )
            model_inputs["input_ids"].append(input_ids)
            model_inputs["attention_mask"].append([1] * len(input_ids))
            model_inputs["labels"].append(labels)
            model_inputs["images"].append(examples["_images"][i])
            model_inputs["videos"].append(examples["_videos"][i])
            model_inputs["audios"].append(examples["_audios"][i])
        return model_inputs

    def print_data_example(self, example: dict[str, list[int]]) -> None:
        valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(self.tokenizer.decode(valid_labels, skip_special_tokens=False)))
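
# Usage sketch (illustrative, not part of the original module): the processor
# is normally applied with `datasets.Dataset.map` in batched mode. The
# constructor arguments below are assumptions based on the fields this class
# accesses (template, tokenizer, processor, data_args); check the surrounding
# codebase for the actual signature.
#
#     processor = SupervisedDatasetProcessor(
#         template=template, tokenizer=tokenizer, processor=mm_processor, data_args=data_args
#     )
#     dataset = dataset.map(
#         processor.preprocess_dataset,
#         batched=True,
#         remove_columns=["_prompt", "_response", "_system", "_tools", "_images", "_videos", "_audios"],
#     )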