import os
import json
import numpy as np
from datetime import datetime
from copy import deepcopy
from typing import (
    Any,
    AsyncIterable,
    Callable,
    Dict,
    Generator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import requests
from io import BytesIO
from PIL import Image
import torch
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights

from data.transforms import ImageTransform
from data.data_utils import pil_img2rgb, add_special_tokens
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.autoencoder import load_ae
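
# This script rebuilds the Bagel architecture with empty weights, dispatches a
# bf16 checkpoint across the available GPUs with accelerate, and then runs
# interleaved text+image chain-of-thought inference, saving every reasoning
# step (text and generated images) plus a JSON summary to a timestamped folder.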
# Set paths for your trained checkpoint.
# checkpoint_dir supplies the configs/tokenizer; checkpoint_path supplies the weights.
origin_checkpoint_dir = "/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8"

# Earlier fine-tuned checkpoints, kept for reference:
# checkpoint_dir = "/scratch/by2593/merged_checkpoint_final"
# checkpoint_dir = "/scratch/by2593/project/Bagel-Zebra-CoT/weights/checkpoints_smm_semantic_part1_reorder_questionimage/0000150"
# checkpoint_dir = '/scratch/by2593/project/Bagel-Zebra-CoT/weights/checkpoints_smm_semantic_part1_reorder_v2_test/000010'
# checkpoint_dir = '/scratch/by2593/project/Bagel-Zebra-CoT/weights/checkpoints_smm_semantic_part1_reorder_v2/000150'
# checkpoint_dir = '/scratch/by2593/project/Bagel-Zebra-CoT/weights/checkpoints_smm_semantic_part1_v1_final/0000500'
checkpoint_dir = "/scratch/by2593/hf_cache/hub/models--multimodal-reasoning-lab--Bagel-Zebra-CoT/snapshots/ebce32410ee2062d073feae484ea2c6c1515fba8"

checkpoint_file = "model.safetensors"
# checkpoint_file = "model_bf16.safetensors"
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
# Override with the specific fine-tuned weights file used for this run:
checkpoint_path = "/scratch/by2593/Bagel-Zebra-CoT-origin/results/checkpoints_smm_semantic_part1_v1_origin/0000050/model.safetensors"
| print(f"Available GPUs: {torch.cuda.device_count()}") | |
| print(f"GPU memory per device:") | |
| for i in range(torch.cuda.device_count()): | |
| props = torch.cuda.get_device_properties(i) | |
| print(f" GPU {i}: {props.name}, {props.total_memory / 1e9:.1f} GB") | |
# LLM config preparing (use base model configs)
llm_config = Qwen2Config.from_json_file(os.path.join(checkpoint_dir, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"
# ViT config preparing (use base model configs)
vit_config = SiglipVisionConfig.from_json_file(os.path.join(checkpoint_dir, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1  # drop the final SigLIP layer, as in the upstream Bagel inference setup
# VAE loading (use base model VAE)
vae_model, vae_config = load_ae(local_path=os.path.join(origin_checkpoint_dir, "ae.safetensors"))
# Bagel config preparing
config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,  # defaults to 64; change to the actual latent size
)
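
# Sizing note (inferred from the values in this file, not from upstream docs):
# 64 * 16 (the vae_transform stride below) = 1024, the maximum image side set
# below, so max_latent_size appears to be the number of latent patches per side
# for the largest image; if you change the target resolution, this value likely
# needs to change with it.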
# Import the position embedding function first
from modeling.bagel.modeling_utils import get_2d_sincos_pos_embed

# Create model with empty weights
with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
# Initialize position embeddings with proper values BEFORE loading checkpoint
print("Initializing position embeddings before loading...")

# Initialize latent_pos_embed if it exists
if hasattr(model, 'latent_pos_embed'):
    print("Initializing latent_pos_embed...")
    pos_embed = get_2d_sincos_pos_embed(model.latent_pos_embed.hidden_size, model.latent_pos_embed.max_num_patch_per_side)
    # Create parameter with actual values, not meta
    model.latent_pos_embed.pos_embed = torch.nn.Parameter(
        torch.from_numpy(pos_embed).float(), requires_grad=False
    )
    print(f"latent_pos_embed initialized with shape {model.latent_pos_embed.pos_embed.shape}")

# Initialize vit_pos_embed if it exists
if hasattr(model, 'vit_pos_embed'):
    print("Initializing vit_pos_embed...")
    pos_embed = get_2d_sincos_pos_embed(model.vit_pos_embed.hidden_size, model.vit_pos_embed.max_num_patch_per_side)
    # Create parameter with actual values, not meta
    model.vit_pos_embed.pos_embed = torch.nn.Parameter(
        torch.from_numpy(pos_embed).float(), requires_grad=False
    )
    print(f"vit_pos_embed initialized with shape {model.vit_pos_embed.pos_embed.shape}")

print("Position embeddings initialized successfully")
# Tokenizer Preparing (use base model tokenizer)
tokenizer = Qwen2Tokenizer.from_pretrained(checkpoint_dir)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

# Image Transform Preparing
vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 512, 14)
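
# The three ImageTransform arguments appear to be (max side, min side, stride):
# a 1024/512/16 transform for the VAE path and a 980/512/14 transform for the
# ViT path, the strides matching the two encoders' spatial granularity.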
# Device mapping for 8x80GB GPUs - use bf16 directly
max_mem_per_gpu = "80GiB"

print("Setting up device mapping...")
device_map = infer_auto_device_map(
    model,
    max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
    dtype=torch.bfloat16,  # Use bf16 for device mapping
)
print("Device map:", device_map)
# Pin these small, tightly coupled modules to a single device (the embedding
# table's) so they are not scattered across GPUs by the auto device map.
same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed'
]

if torch.cuda.device_count() == 1:
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
        else:
            device_map[k] = "cuda:0"
else:
    first_device = device_map.get(same_device_modules[0])
    if first_device is not None:
        for k in same_device_modules:
            if k in device_map:
                device_map[k] = first_device

print("Final device map:", device_map)
# Load checkpoint directly in bf16
print(f"Loading checkpoint directly in bfloat16: {checkpoint_path}")
print("Loading model from safetensors file...")

# Load model directly in bf16
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=checkpoint_path,
    device_map=device_map,
    offload_buffers=False,
    dtype=torch.bfloat16,  # Load directly as bf16
    force_hooks=True,
)
model = model.eval()

print('Model loaded directly in bfloat16!')
print(f"Model dtype: {next(model.parameters()).dtype}")

# Position embeddings were already initialized before model loading
print("Position embeddings were pre-initialized before loading checkpoint")
print("Model loading completed successfully!")
# Check memory usage
print("GPU memory usage after loading:")
for i in range(torch.cuda.device_count()):
    if torch.cuda.memory_allocated(i) > 0:
        allocated = torch.cuda.memory_allocated(i) / 1e9
        cached = torch.cuda.memory_reserved(i) / 1e9
        print(f"  GPU {i}: {allocated:.1f}GB allocated, {cached:.1f}GB cached")
# Rest of inference code
from inferencer import InterleaveInferencer

inferencer = InterleaveInferencer(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids
)
import random
import numpy as np

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
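
# With do_sample=False and text_temperature=0.0 below, text decoding is already
# deterministic; the seeds above mainly make the diffusion-based image
# generation reproducible across runs.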
inference_hyper = dict(
    do_sample=False,
    text_temperature=0.0,
    cfg_text_scale=4.0,
    cfg_img_scale=2.0,
    cfg_interval=[0.0, 1.0],
    timestep_shift=3.0,
    num_timesteps=50,
    cfg_renorm_min=0.0,
    cfg_renorm_type="text_channel",
)
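
# Rough meaning of the generation hyperparameters (standard CFG / diffusion
# sampling knobs; the exact semantics are defined in inferencer.interleave_inference):
#   cfg_text_scale / cfg_img_scale - classifier-free-guidance strength for the
#       text and image conditions,
#   cfg_interval                   - portion of the denoising trajectory where
#       guidance is applied,
#   timestep_shift, num_timesteps  - sampling schedule for image generation,
#   cfg_renorm_*                   - renormalization applied to the guided update.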
# Full interleaved-CoT system prompt, kept for reference; this run overrides it with an empty prompt.
INTERLEAVED_SYSTEM_PROMPT = '''You are an AI reasoning assistant capable of step-by-step interleaved text and visual chain of thought. Think step by step and use visual aids to enhance your problem-solving.'''
INTERLEAVED_SYSTEM_PROMPT = ''
# Original example (004 case) - commented out
# prompt = '''My goal is to generate a visual guide for constructing a specific shape using a set of blocks. This involves multiple steps, each requiring the addition of a new block to progressively build the final shape. The initial input includes 2 images of multiple blocks that will be used <image_start>[problem_image_1]<image_end><image_start>[problem_image_2]<image_end> and an image of the final desired shape<image_start>[problem_image_3]<image_end>. I need to imagine and generate images of intermediate steps, leading up to the final construction. Step 0 has been completed: a red arch block has been placed on top of the ground. The image after step 0 is provided<image_start>[problem_image_4]<image_end>. Now I need to generate the image for step 1, considering spatial relationships and stability.'''

# Use the new example data (145 case)
prompt = '''Based on the construction task shown below, follow the instructions to complete the build. Given the final desired shape of blocks shown in the first image<image_start>[problem_image_1]<image_end> which is viewed from a Front45 angle, perform a series of specified manipulations. This involves multiple steps, each requiring the addition of a new block to progressively build the final shape. The initial input also includes 3 images of multiple blocks that will be used.<image_start>[problem_image_2]<image_end><image_start>[problem_image_3]<image_end><image_start>[problem_image_4]<image_end> Step 0 has been completed: a orange cylinder block has been placed on top of the ground. The image after step 0 is provided.<image_start>[problem_image_5]<image_end>'''
# Load images from the new example paths (145 case)
image = []
base_path = '/scratch/by2593/project/SMM'
image_paths = [
    f'{base_path}/semantic_blocks_part1/145/final_state/145_final_1.png',      # problem_image_1 - final desired shape
    f'{base_path}/SMM_data/each_block_views_diffposes/cylinder_orange.png',    # problem_image_2 - orange cylinder
    f'{base_path}/SMM_data/each_block_views_diffposes/cuboid3_yellow.png',     # problem_image_3 - yellow cuboid3
    f'{base_path}/SMM_data/each_block_views_diffposes/triangle_orange.png',    # problem_image_4 - orange triangle
    f'{base_path}/semantic_blocks_part1/145/steps/view_1/145_step0_1.png',     # problem_image_5 - image after step 0
]
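
# Optional fail-fast check (a small sketch; the loading loop below already falls
# back to gray placeholders, but a missing file is usually a path typo):
missing_inputs = [p for p in image_paths if not os.path.exists(p)]
if missing_inputs:
    print(f"Warning: {len(missing_inputs)} input image(s) not found: {missing_inputs}")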
| print("Loading input images:") | |
| for i, img_path in enumerate(image_paths): | |
| try: | |
| img = Image.open(img_path).convert('RGB') | |
| image.append(img) | |
| print(f" ✓ Loaded problem_image_{i+1}: {img_path}") | |
| print(f" Image size: {img.size}") | |
| except Exception as e: | |
| print(f" ✗ Failed to load {img_path}: {e}") | |
| # Create a placeholder image if file not found | |
| img = Image.new('RGB', (512, 512), color='gray') | |
| image.append(img) | |
| print(f" ⚠ Using placeholder for problem_image_{i+1}") | |
| print(prompt) | |
| print('-'*50) | |
# Create output folder with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_folder = f"reasoning_output_example_145_{timestamp}"
images_folder = os.path.join(output_folder, "images")
os.makedirs(images_folder, exist_ok=True)
print(f"Output will be saved to: {output_folder}")
# Save the original problem images if they exist
problem_image_paths = []
if image is not None:
    if isinstance(image, list):
        # Handle multiple images
        for i, img in enumerate(image):
            problem_image_path = os.path.join(images_folder, f"problem_image_{i+1}.png")
            relative_path = os.path.join("images", f"problem_image_{i+1}.png")
            img.save(problem_image_path)
            problem_image_paths.append(relative_path)
            print(f"Problem image {i+1} saved at '{problem_image_path}'")
    else:
        # Handle single image
        problem_image_path = os.path.join(images_folder, "problem_image.png")
        relative_path = os.path.join("images", "problem_image.png")
        image.save(problem_image_path)
        problem_image_paths.append(relative_path)
        print(f"Problem image saved at '{problem_image_path}'")
reasoning_text = []
reasoning_images = []
generated_image_paths = []  # Store relative paths to generated reasoning images
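
# Optional helper (a sketch, not used by the loop below, which keeps its original
# inline split): strips the <|im_start|>/<|im_end|> chat markers defensively in
# case the model output omits one of them.
def extract_assistant_text(raw: str) -> str:
    text = raw.split('<|im_end|>')[0]
    if '<|im_start|>' in text:
        text = text.split('<|im_start|>', 1)[1]
    return text.strip()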
# Create input with multiple images properly flattened
if image is not None:
    if isinstance(image, list):
        current_input = [prompt] + image  # Flatten the list of images
    else:
        current_input = [prompt, image]
else:
    current_input = [prompt]
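
# interleave_inference consumes a flat list mixing strings and PIL images in the
# order they should appear in the conversation, which is why multi-image inputs
# are concatenated after the prompt rather than nested.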
# Loop until no more vision_start tokens
iteration = 0
while True:
    # Get understanding output
    print(f"iteration: {iteration}")
    output = inferencer.interleave_inference(current_input, understanding_output=True, system_prompt=INTERLEAVED_SYSTEM_PROMPT, **inference_hyper)

    # Check for stopping conditions
    has_final_answer = 'Final Answer:' in output[0] or '<answer>' in output[0]

    # Stop if we have a final answer OR if there's no vision token (no more images to generate)
    # should_stop = has_final_answer or not has_vision_token
    should_stop = has_final_answer

    if should_stop:
        if output[0].strip():
            extracted_text = output[0].split('<|im_end|>')[0].split('<|im_start|>')[1]
            reasoning_text.append(extracted_text)
            print(extracted_text)
            current_input = current_input + [extracted_text]
        break

    extracted_text = output[0].split('<|im_end|>')[0].split('<|im_start|>')[1]
    reasoning_text.append(extracted_text)
    print(extracted_text)

    # Generate image based on current reasoning
    current_input_with_reasoning = current_input + [extracted_text]
    output = inferencer.interleave_inference(current_input_with_reasoning, system_prompt=INTERLEAVED_SYSTEM_PROMPT, **inference_hyper)
    image_output = output[0]

    # Save and collect the generated image
    reasoning_images.append(image_output)
    image_filename = f'reasoning_image_{iteration + 1}.png'
    image_path = os.path.join(images_folder, image_filename)
    relative_image_path = os.path.join("images", image_filename)  # Relative path for JSON
    image_output.save(image_path)
    generated_image_paths.append(relative_image_path)
    print(f"Image saved at '{image_path}'")

    # Update input for next iteration
    current_input = current_input_with_reasoning + [image_output]
    iteration += 1
    print('-'*50)
# Save reasoning data to JSON
reasoning_data = {
    "timestamp": timestamp,
    "prompt": prompt,
    "system_prompt": INTERLEAVED_SYSTEM_PROMPT,
    "problem_image_paths": problem_image_paths if problem_image_paths else None,
    "response": [
        {
            "step": i + 1,
            "text": text,
            "image_path": generated_image_paths[i] if i < len(generated_image_paths) else None
        }
        for i, text in enumerate(reasoning_text)
    ],
    "total_steps": len(reasoning_text),
    "total_images": len(generated_image_paths)
}

# Save JSON file
json_path = os.path.join(output_folder, "reasoning_data.json")
with open(json_path, 'w', encoding='utf-8') as f:
    json.dump(reasoning_data, f, indent=2, ensure_ascii=False)

print("\nReasoning complete!")
print(f"Output folder: {output_folder}")
print(f"JSON metadata: {json_path}")
print(f"Generated {len(generated_image_paths)} images and {len(reasoning_text)} text steps")

# python infz_bf16.py
# Upstream single-image example (commented out in the original script). It follows
# the same pipeline as above, minus the explicit position-embedding
# pre-initialization, with these settings:
#   checkpoint_dir  = "path/to/your/HF_HOME/models/Bagel-Zebra-CoT"
#   checkpoint_file = "model_bf16.safetensors"
#   VAE loaded from os.path.join(checkpoint_dir, "ae.safetensors")
#   inference_hyper: do_sample=True, text_temperature=0.3
#   INTERLEAVED_SYSTEM_PROMPT additionally asked: 'Provide your final conclusion clearly in the format of "Final Answer: <answer here>"'
#   prompt = '''Subtract all cylinders. Add 1 red sphere. How many objects are left?'''
#   image  = Image.open('test_images/image.png')
#   output_folder = f"reasoning_output_{timestamp}"