import pygame
import sys

def display_menu(screen, fps=60):
    """Run a blocking two-button menu loop on *screen* and return the choice.

    Draws ``assets/background.png`` at the top-left corner and two buttons
    centred horizontally, 50 px above and 50 px below the vertical centre of
    *screen*. A button is drawn with its hover image while the mouse cursor
    is over it.

    Args:
        screen: pygame Surface to draw on. ``pygame.display.flip()`` is called
            every frame, so this is presumably the display surface.
        fps: Maximum redraw rate in frames per second, used to keep the loop
            from spinning a CPU core. ``0`` disables the limit.

    Returns:
        ``"level_selector"`` when the Levels button is left-clicked,
        ``"titlescreen"`` when the Menu button is left-clicked, or ``None``
        when Escape is pressed.

    Closing the window calls ``pygame.quit()`` and ``sys.exit()``, so the
    function does not return in that case.
    """
    # Load images
    button1_img = pygame.image.load("assets/LevelsButton.png")
    button1_hover_img = pygame.image.load("assets/LevelsButtonHover.png")
    button2_img = pygame.image.load("assets/MenuButton.png")
    button2_hover_img = pygame.image.load("assets/MenuButtonHover.png")
    background_img = pygame.image.load("assets/background.png")

    screen_w, screen_h = screen.get_width(), screen.get_height()

    # Size each hit box from its own image so the two PNGs need not share
    # dimensions (previously button 2 reused button 1's width/height).
    button1_rect = button1_img.get_rect()
    button1_rect.topleft = ((screen_w - button1_rect.width) // 2,
                            (screen_h - button1_rect.height) // 2 - 50)
    button2_rect = button2_img.get_rect()
    button2_rect.topleft = ((screen_w - button2_rect.width) // 2,
                            (screen_h - button2_rect.height) // 2 + 50)

    # (hit box, normal image, hover image, value returned on click).
    # Order matters: the first matching button wins a click.
    buttons = (
        (button1_rect, button1_img, button1_hover_img, "level_selector"),
        (button2_rect, button2_img, button2_hover_img, "titlescreen"),
    )

    clock = pygame.time.Clock()

    while True:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_ESCAPE:
                    return None  # Exit menu without choosing an option
            elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                # Use the position recorded with the click, not the current
                # cursor position, which may have moved since the event was
                # queued.
                for rect, _, _, action in buttons:
                    if rect.collidepoint(event.pos):
                        return action

        # Display background
        screen.blit(background_img, (0, 0))

        # Render buttons, swapping in the hover image under the cursor.
        mouse_pos = pygame.mouse.get_pos()
        for rect, image, hover_image, _ in buttons:
            shown = hover_image if rect.collidepoint(mouse_pos) else image
            screen.blit(shown, rect.topleft)

        # Update the display, then wait so the loop runs at most `fps` times
        # per second.
        pygame.display.flip()
        clock.tick(fps)