batch_infer.py
5.2 KB · 130 lines · python Raw
1 #!/usr/bin/env python3
2 """Minimal batch inference CLI for the LocateAnything-3B release code.
3
4 Examples:
5 python batch_infer.py --model /path/to/LocateAnything-3B --attn sdpa \
6 --image demo.jpg --query "person</c>car"
7
8 python batch_infer.py --requests requests.jsonl --batch-size 16 --attn la_flash
9
10 Each JSONL request should contain {"image": "/path/to.jpg", "query": "person</c>car"}.
11 """
12 import argparse
13 import json
14 import os
15 from pathlib import Path
16
17 from PIL import Image
18
19
def _attn_arg(value):
    """Normalize an ``--attn`` value to a canonical attention backend name.

    The value is trimmed, lower-cased and has ``-`` replaced by ``_`` before
    being matched against the known aliases. ``None``, an empty string or a
    blank string all select ``"sdpa"``.

    Args:
        value: Raw option text (or ``None``).

    Returns:
        One of ``"sdpa"``, ``"eager"``, ``"magi"`` or ``"la_flash"``.

    Raises:
        argparse.ArgumentTypeError: If the normalized value is neither a
            canonical backend name nor one of its aliases.
    """
    canonical = ("sdpa", "eager", "magi", "la_flash")
    # Canonical backend -> every alternative spelling accepted for it.
    alias_groups = {
        "sdpa": ("", "torch_sdpa"),
        "eager": ("manual", "torch", "torch_eager"),
        "la_flash": (
            "flash",
            "la_flash",
            "kernel",
            "cuda",
            "range",
            "range_attention",
        ),
    }

    key = (value or "sdpa").strip().lower().replace("-", "_")
    for backend, spellings in alias_groups.items():
        if key in spellings:
            key = backend
            break

    if key in canonical:
        return key
    raise argparse.ArgumentTypeError(
        f"--attn must be one of sdpa, eager, magi, la_flash; got {value!r}"
    )
41
42
def _load_requests(args):
    """Collect ``(image, query)`` request pairs from the parsed CLI arguments.

    Pairs are read first from the JSONL file named by ``args.requests``
    (blank lines are skipped; every other line must be a JSON object with
    ``"image"`` and ``"query"`` keys), then from the repeated
    ``--image``/``--query`` options, which are paired positionally.

    Args:
        args: Namespace exposing ``requests``, ``image`` and ``query``.

    Returns:
        A list of ``(image, query)`` tuples, file rows first, in input order.

    Raises:
        ValueError: If a JSONL line is not valid JSON or lacks a required
            field (the message carries ``path:line``), if ``--image`` and
            ``--query`` were given a different number of times, or if no
            request was supplied at all.
        OSError: If the JSONL file cannot be opened.
    """
    requests = []
    if args.requests:
        # "utf-8-sig" reads plain UTF-8 unchanged but drops a leading BOM,
        # which json.loads would otherwise reject on the first line.
        with open(args.requests, "r", encoding="utf-8-sig") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                where = f"{args.requests}:{lineno}"
                try:
                    row = json.loads(line)
                except json.JSONDecodeError as err:
                    raise ValueError(f"{where}: invalid JSON ({err})") from err
                try:
                    pair = (row["image"], row["query"])
                except (KeyError, TypeError) as err:
                    # KeyError: field missing; TypeError: row is not an object.
                    raise ValueError(
                        f'{where}: expected a JSON object with "image" and "query" fields'
                    ) from err
                requests.append(pair)
    if args.image or args.query:
        if len(args.image or []) != len(args.query or []):
            raise ValueError("--image and --query must appear the same number of times")
        requests.extend(zip(args.image, args.query))
    if not requests:
        raise ValueError("provide --requests JSONL or at least one --image/--query pair")
    return requests
59
60
def main():
    """Run the batch inference command-line interface.

    Steps, in order:

    1. Parse the command line; several defaults come from ``LA_FLASH_*``
       environment variables.
    2. Write the resolved model, attention, vision-attention, scheduler and
       group-size settings back into ``os.environ`` (plus
       ``LA_FLASH_STRICT_ATTN=1`` when ``--strict-attn`` is given).
    3. Import ``batch_utils``, gather the requests and call ``load()``.
    4. Process the requests in chunks of ``--batch-size`` (values below 1 are
       treated as 1) and emit one JSON object per request with the keys
       ``image``, ``query``, ``raw_response`` and ``stats``, either to the
       ``--out`` file or to stdout.
    """
    env = os.environ.get
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--requests", help="JSONL file with image/query fields")
    add("--image", action="append", help="Image path; repeat with --query")
    add("--query", action="append", help="Category query, e.g. person</c>car")
    add("--model", default=env("LA_FLASH_MODEL", "nvidia/LocateAnything-3B"))
    add(
        "--attn",
        type=_attn_arg,
        default=env("LA_FLASH_ATTN", "sdpa"),
        help="LLM attention backend: sdpa, eager, magi, or la_flash",
    )
    add(
        "--vision-attn",
        default=env("LA_FLASH_VISION_ATTN", "auto"),
        choices=["auto", "flash_attention_2", "sdpa", "eager"],
    )
    add("--batch-size", type=int, default=1)
    add(
        "--scheduler",
        default=env("LA_FLASH_HYBRID_SCHEDULER", "eager"),
        choices=["eager", "hold_ar", "ar_first", "pipeline", "adaptive"],
    )
    add("--group-size", type=int, default=int(env("LA_FLASH_HYBRID_GROUP_SIZE", "0")))
    add("--max-new-tokens", type=int, default=2048)
    add("--temperature", type=float, default=0.7)
    add("--top-p", type=float, default=0.9)
    add("--top-k", type=int, default=0)
    add("--repetition-penalty", type=float, default=1.1)
    add(
        "--strict-attn",
        action="store_true",
        help="Fail instead of falling back to SDPA if magi/la_flash is unavailable",
    )
    add("--out", default="", help="Optional output JSONL path; stdout if omitted")
    args = parser.parse_args()
    args.attn = _attn_arg(args.attn)

    # Publish the resolved settings through the environment.
    exported = {
        "LA_FLASH_MODEL": args.model,
        "LA_FLASH_ATTN": args.attn,
        "LA_FLASH_VISION_ATTN": args.vision_attn,
        "LA_FLASH_HYBRID_SCHEDULER": args.scheduler,
        "LA_FLASH_HYBRID_GROUP_SIZE": str(args.group_size),
    }
    for name, setting in exported.items():
        os.environ[name] = setting
    if args.strict_attn:
        os.environ["LA_FLASH_STRICT_ATTN"] = "1"

    # Imported only after the environment is populated; presumably
    # batch_utils reads these settings at import/load time (TODO confirm).
    from batch_utils import generate_batch_hybrid, get_last_hybrid_stats, load
    from batch_utils.hybrid_runtime import load_pil

    requests = _load_requests(args)
    load()

    # A negative top-p / non-positive top-k is forwarded as None.
    top_p = args.top_p
    if top_p < 0:
        top_p = None
    top_k = args.top_k
    if top_k <= 0:
        top_k = None
    step = max(1, args.batch_size)

    sink = open(args.out, "w", encoding="utf-8") if args.out else None
    try:
        for offset in range(0, len(requests), step):
            batch = requests[offset:offset + step]
            outputs = generate_batch_hybrid(
                [(load_pil(path), prompt) for path, prompt in batch],
                temperature=args.temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=args.repetition_penalty,
                max_new_tokens=args.max_new_tokens,
                scheduler=args.scheduler,
                group_size=args.group_size,
            )
            # The same stats object is attached to every row of this batch.
            stats = get_last_hybrid_stats()
            for (image, query), text in zip(batch, outputs):
                encoded = json.dumps(
                    {
                        "image": str(Path(image)),
                        "query": query,
                        "raw_response": text,
                        "stats": stats,
                    },
                    ensure_ascii=False,
                )
                if sink is None:
                    print(encoded, flush=True)
                else:
                    sink.write(encoded + "\n")
    finally:
        if sink is not None:
            sink.close()
126
127
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
130