@@ -131,3 +131,231 @@ def tokenize(samples):
131131 "answer" ,
132132 "input_formatted" ,
133133 ])
134+
class MMLU(Dataset):
    """MMLU (Massive Multitask Language Understanding) dataset loader.

    Formats each test sample as a multiple-choice prompt ("A./B./C./D. ...
    Answer:"), optionally prefixed with a subject-specific few-shot header
    built from the "dev" split, then tokenizes prompts and answer letters.
    """

    @classmethod
    def _format_question(cls, question: str, choices: list[str]) -> str:
        """Render a question and its four choices as an 'Answer:' prompt."""
        return (
            f"{question.strip()}\n"
            f"A. {choices[0]}\n"
            f"B. {choices[1]}\n"
            f"C. {choices[2]}\n"
            f"D. {choices[3]}\n"
            "Answer:"
        )

    @classmethod
    def _format_question_and_answer(cls, question: str, choices: list[str], answer: str) -> str:
        """Render a fully-worked example: the prompt followed by its answer letter."""
        return cls._format_question(question, choices) + f" {answer}"

    @classmethod
    def load_fewshot(cls, num_fewshot: int = 5, fewshot_split: str = "dev"):
        """Build a ``{subject: few-shot header string}`` mapping from *fewshot_split*.

        Returns an empty dict when ``num_fewshot == 0``.
        Raises ``ValueError`` if any subject has fewer than *num_fewshot* samples.
        """
        if num_fewshot == 0:
            return {}

        # BUGFIX: the original rebound `fewshot_split` to the dataset object,
        # so the error message below printed the dataset repr instead of the
        # split name. Keep the name and the dataset in separate variables.
        dataset = load_dataset("cais/mmlu", name="all", split=fewshot_split)
        grouped_fewshot_questions: dict[str, list[str]] = {}

        def group_fewshot_questions(sample):
            # Collect at most `num_fewshot` formatted examples per subject.
            subject = sample["subject"]
            questions = grouped_fewshot_questions.setdefault(subject, [])
            if len(questions) >= num_fewshot:
                return
            # The "answer" column is an index 0-3; convert to a letter A-D.
            answer = chr(ord("A") + sample["answer"])
            questions.append(
                cls._format_question_and_answer(sample["question"], sample["choices"], answer)
            )

        # map() is used purely for its side effect of filling the dict above.
        dataset.map(group_fewshot_questions)

        for subject, questions in grouped_fewshot_questions.items():
            if len(questions) < num_fewshot:
                raise ValueError(
                    f"Not enough samples available in split {fewshot_split} "
                    f"to satisfy {num_fewshot} fewshot samples."
                )

        def combine_questions(subject: str, questions: list[str]) -> str:
            # Join with str.join instead of repeated `+=` (linear, idiomatic).
            formatted_subject = subject.replace("_", " ")
            header = (
                "The following are multiple choice questions (with answers) "
                f"about {formatted_subject}.\n\n"
            )
            return header + "".join(q + "\n\n" for q in questions)

        return {
            subject: combine_questions(subject, questions)
            for subject, questions in grouped_fewshot_questions.items()
        }

    @staticmethod
    def load_dataset(split: str = "test"):
        """Load the combined ("all") MMLU evaluation split. Only "test" is supported."""
        if split != "test":
            raise ValueError("MMLU dataset only supports test split.")
        return load_dataset("cais/mmlu", name="all", split=split)

    @classmethod
    def load_encoded_dataset(cls, tokenizer: PreTrainedTokenizer, context_length: int, split: str, num_fewshot: int = 5, fewshot_split: str = "dev"):
        """Tokenize the evaluation split, prefixing per-subject few-shot headers.

        Prompts are truncated from the LEFT to *context_length* tokens so the
        question and trailing "Answer:" cue survive; labels are the tokenized
        answer letters.
        """
        dataset_split = cls.load_dataset(split)
        fewshot_subject_headers = cls.load_fewshot(num_fewshot, fewshot_split)

        def tokenize(sample):
            questions = sample["question"]
            choices = sample["choices"]
            subjects = sample["subject"]

            formatted = [cls._format_question(q, c) for q, c in zip(questions, choices)]
            if num_fewshot > 0:
                formatted = [
                    fewshot_subject_headers[s] + q for s, q in zip(subjects, formatted)
                ]

            tokenized_question = tokenizer(
                formatted,
                return_token_type_ids=False,
                add_special_tokens=True,
            )

            # Left-truncate every field (input_ids, attention_mask, ...).
            tokenized_question = {
                k: [field[-context_length:] for field in v]
                for k, v in tokenized_question.items()
            }

            tokenized_answer = tokenizer(
                [chr(ord("A") + a) for a in sample["answer"]],
                return_token_type_ids=False,
                add_special_tokens=False,
                return_tensors="pt",
            )

            result = tokenized_question
            result["label"] = tokenized_answer["input_ids"]
            return result

        return dataset_split.map(
            tokenize,
            batched=True,
            remove_columns=[
                "question",
                "subject",
                "choices",
                "answer",
            ])
236+
237+
class MMMLU(Dataset):
    """MMMLU (multilingual MMLU, openai/MMMLU) dataset loader.

    Unlike MMLU, few-shot examples are drawn from the evaluation split itself,
    so one extra example per subject is collected and the current question is
    skipped when assembling each prompt.
    """

    @classmethod
    def _format_question(cls, question: str, choices: tuple[str, str, str, str]) -> str:
        """Render a question and its four choices as an 'Answer:' prompt."""
        # NOTE: annotation fixed from tuple[str] (a 1-tuple) — four choices
        # are always indexed below.
        return (
            f"{question.strip()}\n"
            f"A. {choices[0]}\n"
            f"B. {choices[1]}\n"
            f"C. {choices[2]}\n"
            f"D. {choices[3]}\n"
            "Answer:"
        )

    @classmethod
    def _format_question_and_answer(cls, question: str, choices: tuple[str, str, str, str], answer: str) -> str:
        """Render a fully-worked example: the prompt followed by its answer letter."""
        return cls._format_question(question, choices) + f" {answer}"

    @classmethod
    def load_fewshot(cls, dataset_split, num_fewshot: int = 5) -> dict[str, list[str]]:
        """Collect ``num_fewshot + 1`` formatted examples per subject from *dataset_split*.

        Returns an empty dict when ``num_fewshot == 0``.
        Raises ``ValueError`` if any subject has fewer than *num_fewshot* samples.
        """
        if num_fewshot == 0:
            return {}

        grouped_fewshot_questions: dict[str, list[str]] = {}

        def group_fewshot_questions(sample: dict[str, str]):
            subject = sample["Subject"]
            questions = grouped_fewshot_questions.setdefault(subject, [])
            # We need one extra question so a full prompt can still be built
            # when the current question is itself one of the few-shot pool.
            if len(questions) >= num_fewshot + 1:
                return
            choices = (sample["A"], sample["B"], sample["C"], sample["D"])
            questions.append(
                cls._format_question_and_answer(sample["Question"], choices, sample["Answer"])
            )

        # map() is used purely for its side effect of filling the dict above.
        dataset_split.map(group_fewshot_questions)

        for subject, questions in grouped_fewshot_questions.items():
            if len(questions) < num_fewshot:
                raise ValueError(
                    f"Not enough samples available in split to satisfy {num_fewshot} fewshot samples."
                )

        return grouped_fewshot_questions

    @staticmethod
    def load_dataset(split: str = "default"):
        """Load one language configuration of openai/MMMLU (always its "test" split)."""
        return load_dataset("openai/MMMLU", name=split, split="test")

    @classmethod
    def load_encoded_dataset(cls, tokenizer: PreTrainedTokenizer, context_length: int, split: str, num_fewshot: int = 5):
        """Tokenize *split*, prepending per-subject few-shot examples to each prompt.

        Prompts are truncated from the LEFT to *context_length* tokens; labels
        are the tokenized answer letters.
        """
        dataset_split = cls.load_dataset(split)
        grouped_fewshot_questions = cls.load_fewshot(dataset_split, num_fewshot)

        def tokenize(sample: dict[str, list[str]]):
            questions = sample["Question"]
            subjects = sample["Subject"]

            formatted = [
                cls._format_question(q, (a, b, c, d))
                for q, a, b, c, d in zip(
                    questions, sample["A"], sample["B"], sample["C"], sample["D"]
                )
            ]

            def assemble_fewshot_question(formatted_question: str, subject: str) -> str:
                # BUGFIX: use .get so the zero-shot case (load_fewshot returned
                # {}) no longer raises KeyError; the prompt is then just the
                # bare question, matching MMLU's zero-shot behavior.
                subject_pool = grouped_fewshot_questions.get(subject, [])

                parts: list[str] = []
                for fewshot_question in subject_pool:
                    if len(parts) >= num_fewshot:
                        break
                    # The pool comes from the same split: skip the example
                    # that contains this very question.
                    if formatted_question in fewshot_question:
                        continue
                    parts.append(fewshot_question + "\n\n")
                parts.append(formatted_question)
                return "".join(parts)

            fewshot_formatted = [
                assemble_fewshot_question(q, s) for q, s in zip(formatted, subjects)
            ]

            tokenized_question = tokenizer(
                fewshot_formatted,
                return_token_type_ids=False,
                add_special_tokens=True,
            )

            # Left-truncate every field (input_ids, attention_mask, ...).
            tokenized_question = {
                k: [field[-context_length:] for field in v]
                for k, v in tokenized_question.items()
            }

            tokenized_answer = tokenizer(
                sample["Answer"],
                return_token_type_ids=False,
                add_special_tokens=False,
                return_tensors="pt",
            )

            result = tokenized_question
            result["label"] = tokenized_answer["input_ids"]
            return result

        return dataset_split.map(
            tokenize,
            batched=True,
            remove_columns=[
                "Question",
                "A",
                "B",
                "C",
                "D",
                "Answer",
                "Subject",
                "Unnamed: 0",
            ],
        )
0 commit comments