import os
import sqlite3
from io import BytesIO

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from scipy.stats import norm
from ultralytics import YOLO

# Load YOLO models
try:
    yolo_model_cataract = YOLO('best-cataract-seg.pt')
    yolo_model_object_detection = YOLO('best-cataract-od.pt')
    print("YOLO models loaded successfully.")
except Exception as e:
    print(f"Error loading YOLO models: {e}")


# Per-stage (mean, std) pairs for the red, green and blue quantities, in that
# order.  Each std is written as (upper - lower) / 4, exactly as the values
# were originally supplied.
# NOTE(review): the provenance of these bounds is not visible in this file.
_STAGE_CHANNEL_STATS = (
    ("Mature", (
        (73.37, (90.12 - 41.49) / 4),
        (89.48, (97.67 - 83.39) / 4),
        (92.15, (117.82 - 75.37) / 4),
    )),
    ("Normal", (
        (67.84, (107.02 - 56.19) / 4),
        (84.85, (89.89 - 80.74) / 4),
        (102.31, (111.34 - 65.58) / 4),
    )),
    ("Immature", (
        (68.83, (85.95 - 41.49) / 4),
        (89.43, (97.67 - 83.39) / 4),
        (96.74, (117.82 - 78.41) / 4),
    )),
)

# Equal prior probability for each of the three stages.
_STAGE_PRIOR = 1 / 3


def calculate_ratios(red_values, green_values, blue_values, total_pixels):
    """Return the per-channel share of total intensity, scaled to 0-255.

    Args:
        red_values: Red intensities of the selected pixels.
        green_values: Green intensities of the selected pixels.
        blue_values: Blue intensities of the selected pixels.
        total_pixels: Number of selected pixels.

    Returns:
        A ``(red, green, blue)`` tuple.  Each entry is that channel's mean
        divided by the sum of the three channel means, times 255.  All three
        are 0 when ``total_pixels`` is 0 or the channel means sum to 0.
    """
    if total_pixels == 0:
        return 0, 0, 0

    red_ratio = np.sum(red_values) / total_pixels
    green_ratio = np.sum(green_values) / total_pixels
    blue_ratio = np.sum(blue_values) / total_pixels
    total_ratio = red_ratio + green_ratio + blue_ratio

    if total_ratio > 0:
        red_quantity = (red_ratio / total_ratio) * 255
        green_quantity = (green_ratio / total_ratio) * 255
        blue_quantity = (blue_ratio / total_ratio) * 255
    else:
        red_quantity, green_quantity, blue_quantity = 0, 0, 0

    return red_quantity, green_quantity, blue_quantity


def cataract_staging(red_quantity, green_quantity, blue_quantity):
    """Pick the stage with the highest posterior under a Gaussian model.

    Each stage models the three colour quantities as independent normal
    distributions (see ``_STAGE_CHANNEL_STATS``) with an equal prior.

    The comparison is done on log-posteriors.  Multiplying three raw
    densities underflows to 0.0 for inputs far from every mean, which would
    make all stages compare equal; summing log-densities keeps them
    distinguishable.

    Args:
        red_quantity: Red quantity as produced by ``calculate_ratios``.
        green_quantity: Green quantity as produced by ``calculate_ratios``.
        blue_quantity: Blue quantity as produced by ``calculate_ratios``.

    Returns:
        ``"Mature"``, ``"Normal"`` or ``"Immature"``.
    """
    observed = (red_quantity, green_quantity, blue_quantity)
    log_prior = np.log(_STAGE_PRIOR)

    best_stage = None
    best_score = -np.inf
    for stage, channel_stats in _STAGE_CHANNEL_STATS:
        score = log_prior + sum(
            norm.logpdf(value, mean, std)
            for value, (mean, std) in zip(observed, channel_stats)
        )
        # ">=" lets a later stage win an exact tie, which is what the
        # previous dict-based lookup did when posteriors were equal.
        if best_stage is None or score >= best_score:
            best_stage, best_score = stage, score

    return best_stage


def add_watermark(image):
    """Paste the logo from ``image-logo.png`` into the bottom-right corner.

    Args:
        image: PIL image to watermark.  It is not modified.

    Returns:
        A new RGB PIL image with the logo composited on top.  If anything
        fails (for example the logo file is missing) the error is printed and
        the original ``image`` is returned unchanged.
    """
    try:
        # Close the logo file as soon as its pixels have been loaded.
        with Image.open('image-logo.png') as logo_file:
            logo = logo_file.convert("RGBA")
        # Keep the caller's image untouched so the fallback below returns it
        # in its original mode.
        base = image.convert("RGBA")

        # Resize logo to a fixed width, keeping its aspect ratio
        basewidth = 100
        wpercent = basewidth / float(logo.size[0])
        hsize = int(wpercent * logo.size[1])
        logo = logo.resize((basewidth, hsize), Image.LANCZOS)

        # Position logo 10 px in from the bottom-right corner
        position = (base.width - logo.width - 10, base.height - logo.height - 10)

        # Composite image
        canvas = Image.new('RGBA', (base.width, base.height), (0, 0, 0, 0))
        canvas.paste(base, (0, 0))
        canvas.paste(logo, position, mask=logo)

        return canvas.convert("RGB")
    except Exception as e:
        print(f"Error adding watermark: {e}")
        return image


def _union_mask(result, size):
    """Merge all instance masks of ``result`` into one 2-D array of ``size``."""
    mask_stack = result.masks.data.cpu().numpy().astype(np.float32)
    # Normalise to (instances, height, width) so one mask and several masks
    # go through the same code path.
    mask_stack = mask_stack.reshape((-1,) + mask_stack.shape[-2:])
    merged = mask_stack.max(axis=0)
    return np.array(Image.fromarray(merged).resize(size, Image.NEAREST))


def _draw_text_panel(pil_image, text):
    """Draw ``text`` in white on a black rectangle near the top-left corner."""
    draw = ImageDraw.Draw(pil_image)

    # Load a larger font (adjust the size as needed)
    font_size = 48
    try:
        font = ImageFont.truetype("font.ttf", size=font_size)
    except IOError:
        font = ImageFont.load_default()
        print("Error: cannot open resource, using default font.")

    # Calculate text bounding box
    text_bbox = draw.textbbox((0, 0), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]
    text_x = 20
    text_y = 40
    padding = 10

    # Draw a filled rectangle for the background
    draw.rectangle(
        [
            text_x - padding,
            text_y - padding,
            text_x + text_width + padding,
            text_y + text_height + padding,
        ],
        fill="black",
    )

    # Draw text on top of the rectangle
    draw.text((text_x, text_y), text, fill=(255, 255, 255, 255), font=font)


def predict_and_visualize(image):
    """Segment the image, stage the masked region and render an overlay.

    Args:
        image: Array with an ``astype`` method that ``Image.fromarray`` can
            read as an RGB image.

    Returns:
        A 6-tuple ``(image_array, red_quantity, green_quantity,
        blue_quantity, raw_response, stage)``:

        * on success, the watermarked overlay, the three colour quantities,
          ``str(results)`` and the predicted stage;
        * when no mask is found (or the mask covers no pixel), the input
          ``image``, zeros, ``"No mask detected."`` and ``"Unknown"``;
        * on any exception, ``np.zeros_like(image)``, zeros, the error text
          and ``"Error"``.
    """
    try:
        pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
        orig_size = pil_image.size
        results = yolo_model_cataract(pil_image)
        raw_response = str(results)

        if len(results) > 0:
            result = results[0]
            if hasattr(result, 'masks') and result.masks is not None and len(result.masks) > 0:
                original_pixels = np.array(pil_image)
                # Several detections are merged into one region instead of
                # failing on a 3-D mask array.
                mask_resized = _union_mask(result, orig_size)
                selected = mask_resized > 0.5

                pupil_pixels = original_pixels[selected]
                total_pixels = pupil_pixels.shape[0]

                # An empty region has no colour to stage; fall through to the
                # "no mask" result below.
                if total_pixels > 0:
                    red_mask = np.zeros_like(original_pixels)
                    red_mask[selected] = [255, 0, 0]
                    alpha = 0.5
                    blended_image = cv2.addWeighted(original_pixels, 1 - alpha, red_mask, alpha, 0)

                    red_quantity, green_quantity, blue_quantity = calculate_ratios(
                        pupil_pixels[:, 0],
                        pupil_pixels[:, 1],
                        pupil_pixels[:, 2],
                        total_pixels,
                    )
                    stage = cataract_staging(red_quantity, green_quantity, blue_quantity)

                    # Add text to the blended image
                    combined_pil_image = Image.fromarray(blended_image)
                    text = f"Red quantity: {red_quantity:.2f}\nGreen quantity: {green_quantity:.2f}\nBlue quantity: {blue_quantity:.2f}\nStage: {stage}"
                    _draw_text_panel(combined_pil_image, text)

                    # Add watermark to the image
                    combined_pil_image_with_watermark = add_watermark(combined_pil_image)

                    return (
                        np.array(combined_pil_image_with_watermark),
                        red_quantity,
                        green_quantity,
                        blue_quantity,
                        raw_response,
                        stage,
                    )

        return image, 0, 0, 0, "No mask detected.", "Unknown"
    except Exception as e:
        print("Error:", e)
        return np.zeros_like(image), 0, 0, 0, str(e), "Error"


def check_duplicate_entry(conn, red_quantity, green_quantity, blue_quantity, stage):
    """Return True if ``cataract_results`` already has a row with these values.

    Args:
        conn: Open sqlite3 connection.
        red_quantity: Red quantity to match exactly.
        green_quantity: Green quantity to match exactly.
        blue_quantity: Blue quantity to match exactly.
        stage: Stage label to match exactly.
    """
    cursor = conn.cursor()
    query = '''SELECT COUNT(*) FROM cataract_results
               WHERE red_quantity=? AND green_quantity=? AND blue_quantity=? AND stage=?'''
    cursor.execute(query, (red_quantity, green_quantity, blue_quantity, stage))
    count = cursor.fetchone()[0]
    return count > 0


def save_cataract_prediction_to_db(image, red_quantity, green_quantity, blue_quantity, stage):
    """Store one prediction in ``cataract_results.db`` unless it is a duplicate.

    Args:
        image: PIL image; it is stored as PNG bytes.
        red_quantity: Red quantity to store.
        green_quantity: Green quantity to store.
        blue_quantity: Blue quantity to store.
        stage: Stage label to store.

    Returns:
        A ``(status_message, detail_message)`` pair describing whether the
        row was saved, skipped as a duplicate, or could not be saved because
        no connection was available.
    """
    database = "cataract_results.db"
    conn = create_connection(database)
    if conn is None:
        return "Failed to save data", "No connection to the database."

    # The connection is closed on every path, including when encoding the
    # image or running a statement raises.
    try:
        create_cataract_table(conn)

        # Check for duplicate entries
        if check_duplicate_entry(conn, red_quantity, green_quantity, blue_quantity, stage):
            return "Duplicate entry found, not saving.", "Duplicate entry detected."

        sql = '''INSERT INTO cataract_results(image, red_quantity, green_quantity, blue_quantity, stage)
                 VALUES(?,?,?,?,?)'''
        cur = conn.cursor()

        # Convert the image to bytes
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_bytes = buffered.getvalue()

        cur.execute(sql, (img_bytes, red_quantity, green_quantity, blue_quantity, stage))
        conn.commit()
    finally:
        conn.close()

    return "Data saved successfully", f"Red: {red_quantity}, Green: {green_quantity}, Blue: {blue_quantity}, Stage: {stage}"
def combined_prediction(image):
    """Run the segmentation pipeline and store its result in the database.

    Args:
        image: Input image, passed straight to ``predict_and_visualize``.

    Returns:
        An 8-tuple: the six values returned by ``predict_and_visualize``
        followed by ``(save_message, debug_info)``.  When the prediction
        stage is ``"Error"`` nothing is written to the database;
        ``save_message`` says so and ``debug_info`` carries the error text.
    """
    blended_image, red_quantity, green_quantity, blue_quantity, raw_response, stage = predict_and_visualize(image)

    if stage == "Error":
        # The image returned for a failed prediction is only a placeholder,
        # so it is not stored as a result.
        save_message, debug_info = "Prediction failed, not saving.", raw_response
    else:
        save_message, debug_info = save_cataract_prediction_to_db(
            Image.fromarray(blended_image),
            red_quantity,
            green_quantity,
            blue_quantity,
            stage,
        )

    return blended_image, red_quantity, green_quantity, blue_quantity, raw_response, stage, save_message, debug_info


def create_connection(db_file):
    """Create a database connection to the SQLite database.

    Args:
        db_file: Path of the database file handed to ``sqlite3.connect``.

    Returns:
        The open connection, or ``None`` if connecting raised
        ``sqlite3.Error`` (the error is printed).
    """
    try:
        return sqlite3.connect(db_file)
    except sqlite3.Error as e:
        print(e)
        return None


def create_cataract_table(conn):
    """Create the cataract results table if it does not exist.

    Args:
        conn: Open sqlite3 connection.

    A ``sqlite3.Error`` raised while creating the table is printed and not
    re-raised.
    """
    create_table_sql = """
    CREATE TABLE IF NOT EXISTS cataract_results (
        id integer PRIMARY KEY,
        image blob,
        red_quantity real,
        green_quantity real,
        blue_quantity real,
        stage text
    );
    """
    try:
        cursor = conn.cursor()
        cursor.execute(create_table_sql)
    except sqlite3.Error as e:
        print(e)


def predict_object_detection(image):
    """Run the object-detection model and draw labelled boxes on the image.

    Args:
        image: Input image; it is converted with ``np.array`` before use.

    Returns:
        ``(annotated_image, predictions_text)``.  ``annotated_image`` is the
        watermarked array with one box and label per detection, and
        ``predictions_text`` has one line per detection.  On any exception
        the error is printed and ``(np.zeros_like(image), str(error))`` is
        returned.
    """
    try:
        image_np = np.array(image)
        results = yolo_model_object_detection(image_np)
        image_with_boxes = image_np.copy()
        raw_predictions = []

        font_scale = 1.0
        thickness = 2

        for box in results[0].boxes:
            # Class index 1 is reported as "Normal", anything else as "Cataract".
            label = "Normal" if box.cls.item() == 1 else "Cataract"
            confidence = box.conf.item()
            xmin, ymin, xmax, ymax = map(int, box.xyxy[0])

            cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)

            text = f'{label} {confidence:.2f}'
            (text_width, text_height), baseline = cv2.getTextSize(
                text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
            )

            # The label normally sits just above the box.  If the box touches
            # the top of the image, clamp the label to row 0 so it stays
            # visible instead of being drawn off-canvas.
            label_top = max(ymin - text_height - baseline, 0)
            label_bottom = label_top + text_height + baseline

            cv2.rectangle(
                image_with_boxes,
                (xmin, label_top),
                (xmin + text_width, label_bottom),
                (0, 0, 0),
                cv2.FILLED,
            )
            cv2.putText(
                image_with_boxes,
                text,
                (xmin, label_bottom - baseline),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_scale,
                (255, 255, 255),
                thickness,
            )

            raw_predictions.append(
                f"Label: {label}, Confidence: {confidence:.2f}, Box: [{xmin}, {ymin}, {xmax}, {ymax}]"
            )

        raw_predictions_str = "\n".join(raw_predictions)

        # Convert image_with_boxes to PIL image and add watermark
        image_with_boxes_pil = Image.fromarray(image_with_boxes)
        image_with_boxes_pil_with_watermark = add_watermark(image_with_boxes_pil)

        return np.array(image_with_boxes_pil_with_watermark), raw_predictions_str
    except Exception as e:
        print("Error in object detection:", e)
        return np.zeros_like(image), str(e)