基于Qwen2-VL模型针对LaTeX OCR任务进行微调训练 - 多图推理

flyfish

基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_LoRA配置如何写
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_单图推理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_原模型_单图推理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_原模型_多图推理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_多图推理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_数据处理
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_训练
基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_训练过程
输入两张图像

请添加图片描述
请添加图片描述
输出
请添加图片描述
可视化
Image 1:
E m m ˉ = 2 7 Q c π 1 / 2 Γ ( 1 / 4 ) 2 log ⁡ ( L 0 / L ) L ∫ 1 ∞ d y y 2 y 4 − 1 . E _ { m \bar { m } } = \frac { 2 ^ { 7 } \sqrt { Q _ { c } } \pi ^ { 1 / 2 } } { \Gamma ( 1 / 4 ) ^ { 2 } } \frac { \log \left( L _ { 0 } / L \right) } { L } \int _ { 1 } ^ { \infty } d y \frac { y ^ { 2 } } { \sqrt { y ^ { 4 } - 1 } } . Emmˉ=Γ(1/4)227Qc π1/2Llog(L0/L)1dyy41 y2.

Image 2:
u ( τ ) ‾ = u ( − τ ˉ ) , u ( τ + 1 ) = − u ( τ ) , \overline { { u ( \tau ) } } = u ( - \bar { \tau } ) , \qquad \qquad u ( \tau + 1 ) = - u ( \tau ) , u(τ)=u(τˉ),u(τ+1)=u(τ),

import argparse
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from peft import PeftModel, LoraConfig, TaskType
import torch

class LaTeXOCR:
    def __init__(self, local_model_path, lora_model_path):
        self.local_model_path = local_model_path
        self.lora_model_path = lora_model_path
        self._load_model_and_processor()

    def _load_model_and_processor(self):
        config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            inference_mode=True,
            r=64,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
        )

        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.local_model_path, torch_dtype=torch.float16, device_map="auto"
        )
        self.model = PeftModel.from_pretrained(
            self.model, self.lora_model_path, config=config
        )
        self.processor = AutoProcessor.from_pretrained(self.local_model_path)

    def generate_latex_from_images(self, test_image_paths, prompt):
        """
        根据给定的测试图像路径列表和提示信息,生成对应的LaTeX格式文本。

        参数:
            test_image_paths (list of str): 包含数学公式的测试图像路径列表。
            prompt (str): 提供给模型的提示信息。

        返回:
            list of str: 转换后的LaTeX格式文本列表。
        """
        results = []
        for image_path in test_image_paths:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image_path,
                            "resized_height": 100,
                            "resized_width": 500,
                        },
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")

            with torch.no_grad():
                generated_ids = self.model.generate(**inputs, max_new_tokens=8192)

            generated_ids_trimmed = [
                out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )

            results.append(output_text[0])
        
        return results


def parse_arguments():
    parser = argparse.ArgumentParser(description="LaTeX OCR using Qwen2-VL")

    parser.add_argument(
        "--local_model_path",
        type=str,
        default="./Qwen/Qwen2-VL-7B-Instruct",
        help='Path to the local model.',
    )
    parser.add_argument(
        "--lora_model_path",
        type=str,
        default="./output/Qwen2-VL-7B-LatexOCR/checkpoint-1500",
        help='Path to the LoRA model checkpoint.',
    )
    parser.add_argument(
        "--test_image_paths",
        nargs='+',  # 接受多个参数
        type=str,
        default=["./LaTeX_OCR/987.jpg", "./LaTeX_OCR/986.jpg"],  # 设置默认值为两个图像路径
        help='Paths to the test images.',
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_arguments()

    prompt = (
        "尊敬的Qwen2VL大模型,我需要你帮助我将一张包含数学公式的图片转换成LaTeX格式的文本。\n"
        "请按照以下说明进行操作:\n"
        "1. **图像中的内容**: 图像中包含的是一个或多个数学公式,请确保准确地识别并转换为LaTeX代码。\n"
        "2. **公式识别**: 请专注于识别和转换数学符号、希腊字母、积分、求和、分数、指数等数学元素。\n"
        "3. **LaTeX语法**: 输出时使用标准的LaTeX语法。确保所有的命令都是正确的,并且可以被LaTeX编译器正确解析。\n"
        "4. **结构保持**: 如果图像中的公式有特定的结构(例如多行公式、矩阵、方程组),请在输出的LaTeX代码中保留这些结构。\n"
        "5. **上下文无关**: 不要尝试解释公式的含义或者添加额外的信息,只需严格按照图像内容转换。\n"
        "6. **格式化**: 如果可能的话,使输出的LaTeX代码易于阅读,比如适当添加空格和换行。"
    )

    latex_ocr = LaTeXOCR(args.local_model_path, args.lora_model_path)
    results = latex_ocr.generate_latex_from_images(args.test_image_paths, prompt)

    for i, result in enumerate(results):
        print(f"Image {i + 1}:")
        print(result)
        print("-" * 80)
Logo

欢迎加入 MCP 技术社区!与志同道合者携手前行,一同解锁 MCP 技术的无限可能!

更多推荐