generate_openelm.py
7.3 KB · 241 lines · python Raw
1 #
2 # For licensing see accompanying LICENSE file.
3 # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 #
5
6 """Module to generate OpenELM output given a model and an input prompt."""
import argparse
import logging
import os
import shutil
import time
from typing import Optional, Tuple, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
15
16
def _resolve_device(device: Optional[str]) -> str:
    """Return the inference device, auto-selecting one when none is given.

    Args:
        device: Requested device string, or None/empty to auto-select.

    Returns:
        ``device`` unchanged when it is set; otherwise ``'cuda:0'`` when a
        CUDA device is available and ``'cpu'`` when it is not.

    Raises:
        ValueError: If a CUDA device is requested but CUDA is unavailable.
    """
    if not device:
        if torch.cuda.is_available() and torch.cuda.device_count():
            device = 'cuda:0'
            logging.warning(
                'inference device is not set, using cuda:0, %s',
                torch.cuda.get_device_name(0),
            )
        else:
            device = 'cpu'
            logging.warning(
                'No CUDA device detected, using cpu, expect slower speeds.'
            )

    if 'cuda' in str(device) and not torch.cuda.is_available():
        raise ValueError('CUDA device requested but no CUDA device detected.')
    return device


def _load_causal_lm(model, device: str):
    """Load a causal LM from a checkpoint path, or pass an instance through.

    A string is treated as the path to an HF-converted checkpoint. It is
    loaded with ``trust_remote_code=True``, moved to ``device`` and switched
    to eval mode. Any other value is returned unchanged; it is neither moved
    nor switched to eval mode.
    """
    if isinstance(model, str):
        model = AutoModelForCausalLM.from_pretrained(
            model,
            trust_remote_code=True,
        )
        model.to(device).eval()
    return model


def generate(
    prompt: str,
    model: Union[str, AutoModelForCausalLM],
    hf_access_token: Optional[str] = None,
    tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf',
    device: Optional[str] = None,
    max_length: int = 1024,
    assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None,
    generate_kwargs: Optional[dict] = None,
) -> Tuple[str, float]:
    """Generates output given a prompt.

    Args:
        prompt: The string prompt.
        model: The LLM Model. If a string is passed, it should be the path to
            the hf converted checkpoint; it is then loaded, moved to
            ``device`` and put in eval mode. A model instance is used as-is.
        hf_access_token: Hugging face access token. Required only when
            ``tokenizer`` is given as a string, because it is passed to
            ``AutoTokenizer.from_pretrained``.
        tokenizer: Tokenizer instance, or a name/path that is loaded with
            ``AutoTokenizer.from_pretrained``.
        device: String representation of device to run the model on. If None
            and cuda available it would be set to cuda:0 else cpu.
        max_length: Maximum length of tokens, input prompt + generated tokens.
        assistant_model: If set, this model will be used for
            speculative generation. If a string is passed, it should be the
            path to the hf converted checkpoint.
        generate_kwargs: Extra kwargs passed to the hf generate function.
            They override the defaults set here (``max_length``,
            ``pad_token_id=0`` and ``assistant_model``).

    Returns:
        A ``(output_text, generation_time)`` tuple. ``output_text`` is the
        decoded prompt + generated output with special tokens skipped.
        ``generation_time`` is the wall-clock duration of the
        ``model.generate`` call in seconds.

    Raises:
        ValueError: If device is set to CUDA but no CUDA device is detected.
        ValueError: If tokenizer is not set.
        ValueError: If hf_access_token is not specified while the tokenizer
            has to be loaded by name.
    """
    device = _resolve_device(device)

    if not tokenizer:
        raise ValueError('Tokenizer is not set in the generate function.')

    # The token is only used to load the tokenizer by name; a ready-made
    # tokenizer instance does not need one.
    if isinstance(tokenizer, str) and not hf_access_token:
        raise ValueError((
            'Hugging face access token needs to be specified. '
            'Please refer to https://huggingface.co/docs/hub/security-tokens'
            ' to obtain one.'
        ))

    model = _load_causal_lm(model, device)
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer,
            token=hf_access_token,
        )

    # Speculative mode: an optional draft model for assisted generation.
    draft_model = _load_causal_lm(assistant_model, device) if assistant_model else None

    # Token ids as a (1, seq_len) batch. Force an integer dtype so that an
    # empty tokenization does not become a float tensor.
    input_ids = torch.tensor(
        tokenizer(prompt)['input_ids'],
        dtype=torch.long,
        device=device,
    ).unsqueeze(0)

    # Merge user kwargs over the defaults. Passing them alongside explicit
    # keywords would raise TypeError on any duplicate key.
    call_kwargs = {
        'max_length': max_length,
        'pad_token_id': 0,
        'assistant_model': draft_model,
    }
    if generate_kwargs:
        call_kwargs.update(generate_kwargs)

    start = time.perf_counter()
    output_ids = model.generate(input_ids, **call_kwargs)
    generation_time = time.perf_counter() - start

    output_text = tokenizer.decode(
        output_ids[0].tolist(),
        skip_special_tokens=True,
    )
    return output_text, generation_time
133
134
def openelm_generate_parser():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``model``, ``hf_access_token``, ``prompt``,
        ``device``, ``max_length``, ``assistant_model`` and
        ``generate_kwargs``. ``generate_kwargs`` is a dict, or None when the
        flag is absent.
    """

    class KwargsParser(argparse.Action):
        """Parser action collecting ``key=value`` tokens into a dict.

        Each value is converted to ``int`` if possible, else ``float``, else
        ``bool`` for ``true``/``false`` (case-insensitive); anything else is
        kept as a string. A token without ``=`` is reported as a normal
        argparse usage error.
        """

        @staticmethod
        def _convert(raw):
            """Convert one raw value string to int, float, bool or str."""
            for cast in (int, float):
                try:
                    return cast(raw)
                except ValueError:
                    pass
            # Without this, e.g. do_sample=False would reach HF generate as
            # the truthy string 'False'.
            lowered = raw.lower()
            if lowered == 'true':
                return True
            if lowered == 'false':
                return False
            return raw

        def __call__(self, parser, namespace, values, option_string=None):
            kwargs = {}
            setattr(namespace, self.dest, kwargs)
            for val in values:
                # Split on the first '=' only, so values may contain '='.
                kwarg_k, sep, kwarg_v = val.partition('=')
                if not sep:
                    # argparse turns ArgumentError into a usage message + exit.
                    raise argparse.ArgumentError(
                        self,
                        (
                            'Argument parsing error, kwargs are expected in'
                            ' the form of key=value.'
                        ),
                    )
                kwargs[kwarg_k] = self._convert(kwarg_v)

    # NOTE: the string is the parser's ``prog`` (shown in usage), as before.
    parser = argparse.ArgumentParser(prog='OpenELM Generate Module')
    parser.add_argument(
        '--model',
        dest='model',
        help='Path to the hf converted model.',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--hf_access_token',
        dest='hf_access_token',
        help='Hugging face access token, starting with "hf_".',
        type=str,
    )
    parser.add_argument(
        '--prompt',
        dest='prompt',
        help='Prompt for LLM call.',
        default='',
        type=str,
    )
    parser.add_argument(
        '--device',
        dest='device',
        help='Device used for inference.',
        type=str,
    )
    parser.add_argument(
        '--max_length',
        dest='max_length',
        help='Maximum length of tokens.',
        default=256,
        type=int,
    )
    parser.add_argument(
        '--assistant_model',
        dest='assistant_model',
        help=(
            'If set, this is used as a draft model '
            'for assisted speculative generation.'
        ),
        type=str,
    )
    parser.add_argument(
        '--generate_kwargs',
        dest='generate_kwargs',
        help='Additional kwargs passed to the HF generate function.',
        type=str,
        nargs='*',
        action=KwargsParser,
    )
    return parser.parse_args()
214
215
if __name__ == '__main__':
    # Script entry point: parse CLI args, run generation, pretty-print result.
    args = openelm_generate_parser()

    output_text, generation_time = generate(
        prompt=args.prompt,
        model=args.model,
        device=args.device,
        max_length=args.max_length,
        assistant_model=args.assistant_model,
        generate_kwargs=args.generate_kwargs,
        hf_access_token=args.hf_access_token,
    )

    # shutil.get_terminal_size falls back to 80 columns when stdout is not a
    # terminal (pipes, redirects, CI); os.get_terminal_size raises OSError.
    width = shutil.get_terminal_size().columns
    print_txt = (
        f'\r\n{"=" * width}\r\n'
        '\033[1m Prompt + Generated Output\033[0m\r\n'
        f'{"-" * width}\r\n'
        f'{output_text}\r\n'
        f'{"-" * width}\r\n'
        '\r\nGeneration took'
        f'\033[1m\033[92m {round(generation_time, 2)} \033[0m'
        'seconds.\r\n'
    )
    print(print_txt)
241