Last active
July 1, 2024 14:23
-
-
Save halr9000/f07a866d16ccad6d23198b9118ccec16 to your computer and use it in GitHub Desktop.
Python script that captions images with microsoft/Florence-2 running locally via Pinokio and Gradio. Paper page: https://huggingface.co/papers/2311.06242. Model card: https://huggingface.co/microsoft/Florence-2-large. Gradio app: https://pinokio.computer/item?uri=https://github.com/pinokiofactory/florence2.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import argparse | |
| import ast | |
| from gradio_client import Client, handle_file | |
| import json | |
| import logging | |
# Only ERROR and above are emitted by default; each record is prefixed with a
# timestamp and its level name.
logging.basicConfig(
    level=logging.ERROR,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Florence-2 tasks selectable on the command line: human-readable task name ->
# the corresponding "<...>" task-prompt token. In this script only the keys are
# consumed (as the --task choices); the token values are kept for reference.
task_prompts = dict((
    ("Caption", "<CAPTION>"),
    ("Detailed Caption", "<DETAILED_CAPTION>"),
    ("More Detailed Caption", "<MORE_DETAILED_CAPTION>"),
    ("Object Detection", "<OD>"),
    ("Dense Region Caption", "<DENSE_REGION_CAPTION>"),
    ("Region Proposal", "<REGION_PROPOSAL>"),
    ("Caption to Phrase Grounding", "<CAPTION_TO_PHRASE_GROUNDING>"),
    ("Referring Expression Segmentation", "<REFERRING_EXPRESSION_SEGMENTATION>"),
    ("Region to Segmentation", "<REGION_TO_SEGMENTATION>"),
    ("Open Vocabulary Detection", "<OPEN_VOCABULARY_DETECTION>"),
    ("Region to Category", "<REGION_TO_CATEGORY>"),
    ("Region to Description", "<REGION_TO_DESCRIPTION>"),
    ("OCR", "<OCR>"),
    ("OCR with Region", "<OCR_WITH_REGION>"),
))
def _strip_newlines(value):
    """Return a copy of *value* with newlines removed from every string in it.

    Both real newline characters and the two-character escape sequence
    backslash-n are deleted (not replaced with spaces). Dicts are rebuilt with
    the same keys, lists and tuples become lists, and any other value (numbers,
    None, ...) is returned unchanged.
    """
    if isinstance(value, str):
        return value.replace('\n', '').replace('\\n', '')
    if isinstance(value, dict):
        return {key: _strip_newlines(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_strip_newlines(item) for item in value]
    return value


def main(image_url: str, task: str, model_id: str, client_url: str) -> None:
    """Run one Florence-2 task on an image via a Gradio server and print JSON.

    The image is sent to the Gradio app at ``client_url``. The first output of
    the prediction is parsed as a Python literal, newlines are stripped from
    every string inside it, and the result is printed to stdout as a single
    line of JSON. Prediction, parsing and serialization errors are logged and
    the function returns without printing; errors constructing the client
    propagate to the caller.

    Parameters:
    - image_url: Image to process; passed through ``gradio_client.handle_file``.
    - task: Task name sent as ``task_prompt`` (one of the ``task_prompts`` keys
      when invoked from the command line).
    - model_id: Model ID to use for prediction.
    - client_url: URL of the Gradio server.
    """
    logger = logging.getLogger(__name__)

    # The task name itself (e.g. "Caption") is sent as task_prompt; the "<...>"
    # tokens stored in task_prompts are not used here.
    task_prompt = task
    logger.debug("Using task prompt key: %s", task_prompt)

    client = Client(client_url, verbose=False)
    logger.debug("Client instantiated with URL: %s", client_url)

    try:
        result = client.predict(
            image=handle_file(image_url),
            task_prompt=task_prompt,
            text_input=None,
            model_id=model_id,
        )
        logger.debug("Prediction result: %s", result)
    except Exception as e:
        # Boundary around the remote call: log with traceback and give up.
        logger.error(f"Exception occurred during prediction: {e}", exc_info=True)
        return

    # The first output is parsed below; the second one is ignored.
    try:
        result_string, _ = result
    except (TypeError, ValueError):
        logger.error("Unexpected prediction result (expected 2 outputs): %r", result)
        return

    if result_string is None or not result_string.strip():
        logger.warning("Result string is empty or None.")
        return

    logger.debug("Evaluating result string: %s", result_string)
    try:
        # literal_eval only builds Python literals (no names or calls), so the
        # server's reply cannot execute code here; malformed input raises one
        # of the exceptions listed below.
        result_dict = ast.literal_eval(result_string)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        logger.error(f"Exception occurred while evaluating result string: {e}", exc_info=True)
        return

    # Strip recursively so non-string values (nested dicts/lists, numbers —
    # presumably what non-caption tasks return) pass through instead of
    # raising AttributeError on str.replace.
    cleaned = _strip_newlines(result_dict)
    try:
        output = json.dumps(cleaned, indent=None)
    except (TypeError, ValueError) as e:
        # e.g. a set or tuple-keyed dict from literal_eval has no JSON form.
        logger.error("Result is not JSON-serializable: %s", e, exc_info=True)
        return

    print(output)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Image Captioning Script')
    parser.add_argument('--image_url', type=str, required=True, help='URL of the image to caption')
    parser.add_argument('--task', type=str, choices=list(task_prompts.keys()), required=True, help='Image processing task supported by the model')
    parser.add_argument('--model_id', type=str, default="microsoft/Florence-2-large", help='Model ID to use for prediction')
    parser.add_argument('--client_url', type=str, default="http://100.107.248.20:42421/", help='Client server URL')
    # The module-level basicConfig pins the root level at ERROR, which would
    # otherwise hide every debug/info message in this script. The default here
    # keeps that behaviour; pass e.g. --log_level debug to see diagnostics.
    parser.add_argument(
        '--log_level',
        type=str.upper,
        default='ERROR',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
        help='Logging verbosity (default: ERROR)',
    )
    args = parser.parse_args()

    # Copy so that popping the logging option does not mutate ``args``.
    cli_options = dict(vars(args))
    logging.getLogger().setLevel(cli_options.pop('log_level'))

    logger = logging.getLogger(__name__)
    logger.info(f"Running with arguments: {args}")
    # main() does not accept --log_level, so pass only its own parameters.
    main(**cli_options)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment