generate_openelm.py
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#

"""Module to generate OpenELM output given a model and an input prompt."""
import os
import logging
import time
import argparse
from typing import Optional, Tuple, Union
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM


def generate(
    prompt: str,
    model: Union[str, AutoModelForCausalLM],
    hf_access_token: Optional[str] = None,
    tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf',
    device: Optional[str] = None,
    max_length: int = 1024,
    assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None,
    generate_kwargs: Optional[dict] = None,
) -> Tuple[str, float]:
| 27 | """ Generates output given a prompt. |
| 28 | |
| 29 | Args: |
| 30 | prompt: The string prompt. |
| 31 | model: The LLM Model. If a string is passed, it should be the path to |
| 32 | the hf converted checkpoint. |
| 33 | hf_access_token: Hugging face access token. |
| 34 | tokenizer: Tokenizer instance. If model is set as a string path, |
| 35 | the tokenizer will be loaded from the checkpoint. |
| 36 | device: String representation of device to run the model on. If None |
| 37 | and cuda available it would be set to cuda:0 else cpu. |
| 38 | max_length: Maximum length of tokens, input prompt + generated tokens. |
| 39 | assistant_model: If set, this model will be used for |
| 40 | speculative generation. If a string is passed, it should be the |
| 41 | path to the hf converted checkpoint. |
| 42 | generate_kwargs: Extra kwargs passed to the hf generate function. |
| 43 | |
| 44 | Returns: |
| 45 | output_text: output generated as a string. |
| 46 | generation_time: generation time in seconds. |
| 47 | |
| 48 | Raises: |
| 49 | ValueError: If device is set to CUDA but no CUDA device is detected. |
| 50 | ValueError: If tokenizer is not set. |
| 51 | ValueError: If hf_access_token is not specified. |
| 52 | """ |
    if not device:
        if torch.cuda.is_available() and torch.cuda.device_count():
            device = "cuda:0"
            logging.warning(
                'inference device is not set, using cuda:0, %s',
                torch.cuda.get_device_name(0)
            )
        else:
            device = 'cpu'
            logging.warning(
                'No CUDA device detected, using cpu, '
                'expect slower speeds.'
            )

    if 'cuda' in device and not torch.cuda.is_available():
        raise ValueError('CUDA device requested but no CUDA device detected.')

    if not tokenizer:
        raise ValueError('Tokenizer is not set in the generate function.')

    if not hf_access_token:
        raise ValueError(
            'Hugging Face access token needs to be specified. '
            'Please refer to https://huggingface.co/docs/hub/security-tokens'
            ' to obtain one.'
        )

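    # A string `model` is treated as the path (or hub id) of an HF-converted
    # checkpoint and loaded here; an already-instantiated model is used as-is.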
    if isinstance(model, str):
        checkpoint_path = model
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            trust_remote_code=True
        )
    model.to(device).eval()
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer,
            token=hf_access_token,
        )

    # Speculative mode
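    # In assisted (speculative) generation, the smaller draft model proposes
    # candidate tokens and the main model verifies them in a single forward
    # pass, which can speed up decoding.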
    draft_model = None
    if assistant_model:
        draft_model = assistant_model
        if isinstance(assistant_model, str):
            draft_model = AutoModelForCausalLM.from_pretrained(
                assistant_model,
                trust_remote_code=True
            )
        draft_model.to(device).eval()

    # Prepare the prompt
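    # The tokenizer returns a dict of encodings; only the input ids are
    # needed, converted to a tensor on the target device.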
    tokenized_prompt = tokenizer(prompt)
    tokenized_prompt = torch.tensor(
        tokenized_prompt['input_ids'],
        device=device
    )

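    # Add a batch dimension: (seq_len,) -> (1, seq_len).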
    tokenized_prompt = tokenized_prompt.unsqueeze(0)

    # Generate
    stime = time.time()
    output_ids = model.generate(
        tokenized_prompt,
        max_length=max_length,
        pad_token_id=0,
        assistant_model=draft_model,
        **(generate_kwargs if generate_kwargs else {}),
    )
    generation_time = time.time() - stime

    output_text = tokenizer.decode(
        output_ids[0].tolist(),
        skip_special_tokens=True
    )

    return output_text, generation_time
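
# Example usage of generate() (a sketch; the model id, token, and kwargs
# below are placeholders, not part of this module):
#
#   text, seconds = generate(
#       prompt='Once upon a time there was',
#       model='apple/OpenELM-1_1B',
#       hf_access_token='hf_...',
#       generate_kwargs={'repetition_penalty': 1.2},
#   )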


def openelm_generate_parser():
| 136 | """Argument Parser""" |
| 137 | |
| 138 | class KwargsParser(argparse.Action): |
| 139 | """Parser action class to parse kwargs of form key=value""" |
| 140 | def __call__(self, parser, namespace, values, option_string=None): |
| 141 | setattr(namespace, self.dest, dict()) |
| 142 | for val in values: |
| 143 | if '=' not in val: |
| 144 | raise ValueError( |
| 145 | ( |
| 146 | 'Argument parsing error, kwargs are expected in' |
| 147 | ' the form of key=value.' |
| 148 | ) |
| 149 | ) |
| 150 | kwarg_k, kwarg_v = val.split('=') |
| 151 | try: |
| 152 | converted_v = int(kwarg_v) |
| 153 | except ValueError: |
| 154 | try: |
| 155 | converted_v = float(kwarg_v) |
| 156 | except ValueError: |
| 157 | converted_v = kwarg_v |
| 158 | getattr(namespace, self.dest)[kwarg_k] = converted_v |
| 159 | |
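    # KwargsParser turns e.g. `--generate_kwargs repetition_penalty=1.2
    # top_k=40` into {'repetition_penalty': 1.2, 'top_k': 40}.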
    parser = argparse.ArgumentParser('OpenELM Generate Module')
    parser.add_argument(
        '--model',
        dest='model',
        help='Path to the hf converted model.',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--hf_access_token',
        dest='hf_access_token',
        help='Hugging Face access token, starting with "hf_".',
        type=str,
    )
    parser.add_argument(
        '--prompt',
        dest='prompt',
        help='Prompt for LLM call.',
        default='',
        type=str,
    )
    parser.add_argument(
        '--device',
        dest='device',
        help='Device used for inference.',
        type=str,
    )
    parser.add_argument(
        '--max_length',
        dest='max_length',
        help='Maximum number of tokens, input prompt + generated tokens.',
        default=256,
        type=int,
    )
    parser.add_argument(
        '--assistant_model',
        dest='assistant_model',
        help=(
            'If set, this is used as a draft model '
            'for assisted speculative generation.'
        ),
        type=str,
    )
    parser.add_argument(
        '--generate_kwargs',
        dest='generate_kwargs',
        help='Additional kwargs passed to the HF generate function.',
        type=str,
        nargs='*',
        action=KwargsParser,
    )
    return parser.parse_args()


if __name__ == '__main__':
    args = openelm_generate_parser()
    prompt = args.prompt

    output_text, generation_time = generate(
        prompt=prompt,
        model=args.model,
        device=args.device,
        max_length=args.max_length,
        assistant_model=args.assistant_model,
        generate_kwargs=args.generate_kwargs,
        hf_access_token=args.hf_access_token,
    )

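    # Note: os.get_terminal_size() queries the attached terminal; when stdout
    # is redirected (e.g. to a file) it may raise OSError.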
    print_txt = (
        f'\r\n{"=" * os.get_terminal_size().columns}\r\n'
        '\033[1m Prompt + Generated Output\033[0m\r\n'
        f'{"-" * os.get_terminal_size().columns}\r\n'
        f'{output_text}\r\n'
        f'{"-" * os.get_terminal_size().columns}\r\n'
        '\r\nGeneration took'
        f'\033[1m\033[92m {round(generation_time, 2)} \033[0m'
        'seconds.\r\n'
    )
    print(print_txt)
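
# Example invocation (a sketch; the model id and token are placeholders):
#   python generate_openelm.py \
#       --model apple/OpenELM-1_1B \
#       --hf_access_token hf_... \
#       --prompt 'Once upon a time there was' \
#       --generate_kwargs repetition_penalty=1.2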