batch_infer.py
| 1 | #!/usr/bin/env python3 |
| 2 | """Minimal batch inference CLI for the LocateAnything-3B release code. |
| 3 | |
| 4 | Examples: |
| 5 | python batch_infer.py --model /path/to/LocateAnything-3B --attn sdpa \ |
| 6 | --image demo.jpg --query "person</c>car" |
| 7 | |
| 8 | python batch_infer.py --requests requests.jsonl --batch-size 16 --attn la_flash |
| 9 | |
| 10 | Each JSONL request should contain {"image": "/path/to.jpg", "query": "person</c>car"}. |
| 11 | """ |
| 12 | import argparse |
| 13 | import json |
| 14 | import os |
| 15 | from pathlib import Path |
| 16 | |
| 17 | from PIL import Image |
| 18 | |
| 19 | |
def _attn_arg(value):
    """Normalize an ``--attn`` value to a canonical attention backend name.

    The input is stripped, lower-cased, and has hyphens replaced with
    underscores before being looked up in the alias table. ``None``, an empty
    string, or a whitespace-only string resolves to ``"sdpa"``.

    Args:
        value: Raw backend name from the command line or environment.

    Returns:
        One of ``"sdpa"``, ``"eager"``, ``"magi"`` or ``"la_flash"``.

    Raises:
        argparse.ArgumentTypeError: The value is neither a canonical name nor
            a known alias.
    """
    canonical = ("sdpa", "eager", "magi", "la_flash")
    # Accepted spellings, grouped by the backend they resolve to.
    spellings = {
        "sdpa": ("", "torch_sdpa"),
        "eager": ("manual", "torch", "torch_eager"),
        "la_flash": (
            "flash",
            "la_flash",
            "kernel",
            "cuda",
            "range",
            "range_attention",
        ),
    }
    lookup = {
        alias: backend
        for backend, names in spellings.items()
        for alias in names
    }

    raw = value if value else "sdpa"
    key = raw.strip().lower().replace("-", "_")
    # Names without an alias entry (e.g. "magi") pass through unchanged.
    resolved = lookup.get(key, key)
    if resolved in canonical:
        return resolved
    raise argparse.ArgumentTypeError(
        f"--attn must be one of sdpa, eager, magi, la_flash; got {value!r}"
    )
| 41 | |
| 42 | |
def _load_requests(args):
    """Collect ``(image, query)`` request pairs from the JSONL file and CLI flags.

    Args:
        args: Parsed CLI namespace. Reads ``args.requests`` (path to a JSONL
            file, or a falsy value to skip it) and ``args.image`` /
            ``args.query`` (lists of strings, or ``None`` when the flag was
            not given).

    Returns:
        A list of ``(image, query)`` tuples: JSONL rows first, in file order,
        followed by the ``--image``/``--query`` pairs in command-line order.

    Raises:
        json.JSONDecodeError: A non-blank JSONL line is not valid JSON. The
            message is prefixed with ``path:line``.
        KeyError: A JSONL row lacks the ``"image"`` or ``"query"`` field. The
            message names the file, line and missing field.
        ValueError: ``--image`` and ``--query`` were given a different number
            of times, or no request was provided at all.
    """
    requests = []
    if args.requests:
        # "utf-8-sig" reads plain UTF-8 unchanged but also drops a leading
        # byte-order mark, which json.loads would otherwise reject on the
        # first row.
        with open(args.requests, "r", encoding="utf-8-sig") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    row = json.loads(line)
                except json.JSONDecodeError as err:
                    # Same exception type, with the file position added.
                    raise json.JSONDecodeError(
                        f"{args.requests}:{lineno}: {err.msg}", err.doc, err.pos
                    ) from err
                try:
                    requests.append((row["image"], row["query"]))
                except KeyError as err:
                    raise KeyError(
                        f"{args.requests}:{lineno}: missing field {err.args[0]!r}"
                    ) from err

    images = args.image or []
    queries = args.query or []
    if images or queries:
        if len(images) != len(queries):
            raise ValueError("--image and --query must appear the same number of times")
        requests.extend(zip(images, queries))
    if not requests:
        raise ValueError("provide --requests JSONL or at least one --image/--query pair")
    return requests
| 59 | |
| 60 | |
def main():
    """Parse CLI arguments, run batched inference, and emit one JSON line per request.

    Steps, in order:

    1. Parse the command line; several defaults come from ``LA_FLASH_*``
       environment variables.
    2. Collect the requests via ``_load_requests`` so invalid input is
       reported before the runtime is imported or the model is loaded.
    3. Export the chosen settings as ``LA_FLASH_*`` environment variables and
       only then import ``batch_utils``.
    4. Call ``load()`` and process the requests in chunks of ``--batch-size``
       (values below 1 are treated as 1).
    5. Write each result as a JSON object with the keys ``image``, ``query``,
       ``raw_response`` and ``stats`` to ``--out``, or to stdout if omitted.

    Raises:
        ValueError, KeyError, json.JSONDecodeError: Propagated from
            ``_load_requests`` for invalid request input.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--requests", help="JSONL file with image/query fields")
    ap.add_argument("--image", action="append", help="Image path; repeat with --query")
    ap.add_argument("--query", action="append", help="Category query, e.g. person</c>car")
    ap.add_argument("--model", default=os.environ.get("LA_FLASH_MODEL", "nvidia/LocateAnything-3B"))
    ap.add_argument("--attn", type=_attn_arg, default=os.environ.get("LA_FLASH_ATTN", "sdpa"),
                    help="LLM attention backend: sdpa, eager, magi, or la_flash")
    ap.add_argument("--vision-attn", default=os.environ.get("LA_FLASH_VISION_ATTN", "auto"),
                    choices=["auto", "flash_attention_2", "sdpa", "eager"])
    ap.add_argument("--batch-size", type=int, default=1)
    ap.add_argument("--scheduler", default=os.environ.get("LA_FLASH_HYBRID_SCHEDULER", "eager"),
                    choices=["eager", "hold_ar", "ar_first", "pipeline", "adaptive"])
    ap.add_argument("--group-size", type=int, default=int(os.environ.get("LA_FLASH_HYBRID_GROUP_SIZE", "0")))
    ap.add_argument("--max-new-tokens", type=int, default=2048)
    ap.add_argument("--temperature", type=float, default=0.7)
    ap.add_argument("--top-p", type=float, default=0.9)
    ap.add_argument("--top-k", type=int, default=0)
    ap.add_argument("--repetition-penalty", type=float, default=1.1)
    ap.add_argument("--strict-attn", action="store_true",
                    help="Fail instead of falling back to SDPA if magi/la_flash is unavailable")
    ap.add_argument("--out", default="", help="Optional output JSONL path; stdout if omitted")
    args = ap.parse_args()
    # Re-normalizing is harmless: canonical names map to themselves.
    args.attn = _attn_arg(args.attn)

    # Validate the inputs first so a bad request file or mismatched flags fail
    # before the runtime import and model load below.
    requests = _load_requests(args)
    batch_size = max(1, args.batch_size)

    # The settings are exported before batch_utils is imported.
    # NOTE(review): presumably batch_utils reads these at import/load time — confirm.
    os.environ["LA_FLASH_MODEL"] = args.model
    os.environ["LA_FLASH_ATTN"] = args.attn
    os.environ["LA_FLASH_VISION_ATTN"] = args.vision_attn
    os.environ["LA_FLASH_HYBRID_SCHEDULER"] = args.scheduler
    os.environ["LA_FLASH_HYBRID_GROUP_SIZE"] = str(args.group_size)
    if args.strict_attn:
        os.environ["LA_FLASH_STRICT_ATTN"] = "1"

    from contextlib import nullcontext

    from batch_utils import generate_batch_hybrid, get_last_hybrid_stats, load
    from batch_utils.hybrid_runtime import load_pil

    load()

    # The output file is opened only after load() succeeds, so a failed load
    # never truncates an existing file. nullcontext() yields None, which
    # selects stdout below.
    sink = open(args.out, "w", encoding="utf-8") if args.out else nullcontext()
    with sink as writer:
        for start in range(0, len(requests), batch_size):
            chunk = requests[start:start + batch_size]
            pairs = [(load_pil(image), query) for image, query in chunk]
            texts = generate_batch_hybrid(
                pairs,
                temperature=args.temperature,
                # A negative --top-p or non-positive --top-k is passed on as None.
                top_p=None if args.top_p < 0 else args.top_p,
                top_k=None if args.top_k <= 0 else args.top_k,
                repetition_penalty=args.repetition_penalty,
                max_new_tokens=args.max_new_tokens,
                scheduler=args.scheduler,
                group_size=args.group_size,
            )
            # One stats object per batch; it is attached to every row of the chunk.
            stats = get_last_hybrid_stats()
            for (image, query), text in zip(chunk, texts):
                row = {"image": str(Path(image)), "query": query, "raw_response": text, "stats": stats}
                line = json.dumps(row, ensure_ascii=False)
                if writer is not None:
                    writer.write(line + "\n")
                else:
                    print(line, flush=True)
| 126 | |
| 127 | |
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
| 130 | |