import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D


alpha = 0.1  # surface transparency shared by every plane drawn in the script below

def plot_plane(ax, h, k, l, color, alpha, label, extent=1.5):
    """
    Plot the plane with Miller indices (hkl) of a cubic system on a 3D axis.

    Planes with exactly one non-zero index -- (h00), (0k0), (00l) -- are
    drawn through the origin (x=0, y=0 or z=0 respectively). Every other
    plane is drawn as h*x + k*y + l*z = 1, sampled on a 50x50 grid over
    [-extent, extent]; points whose solved coordinate falls outside that
    range are set to NaN so the surface is clipped to the display box.

    Because plot_surface does not take a legend label, an empty proxy line
    with the same color/alpha/label is added so the plane shows up in the
    legend. If all three indices are zero, a warning is printed and nothing
    (not even a legend entry) is added.

    Parameters
    ----------
    ax : 3D axes object providing ``plot_surface`` and ``plot``.
    h, k, l : Miller indices of the plane.
    color : colour used for the surface and its legend entry.
    alpha : transparency used for the surface and its legend entry.
    label : legend text for the plane.
    extent : half-width of the sampling grid and clipping box (default 1.5).
    """
    grid_range = np.linspace(-extent, extent, 50)
    first, second = np.meshgrid(grid_range, grid_range)

    def _clip(values):
        # Hide points that leave the display box instead of stretching the view.
        values[(values < -extent) | (values > extent)] = np.nan
        return values

    nonzero_count = sum(index != 0 for index in (h, k, l))

    if nonzero_count == 0:
        print(f"Warning: Cannot plot plane ({h}{k}{l}) - all indices are zero or a single plane cannot be uniquely defined.")
        return

    if nonzero_count == 1:
        # Axis-aligned plane: drawn through the origin. Checked before the
        # general case so (00l) is treated the same way as (h00) and (0k0).
        zero = np.zeros_like(first)
        if h != 0:
            xx, yy, zz = zero, first, second
        elif k != 0:
            xx, yy, zz = first, zero, second
        else:
            xx, yy, zz = first, second, zero
    elif l != 0:
        # The plane crosses the z-axis: solve h*x + k*y + l*z = 1 for z.
        xx, yy = first, second
        zz = _clip((1 - h * xx - k * yy) / l)
    else:
        # l == 0 with h and k both non-zero: parallel to z, so solve for y.
        xx, zz = first, second
        yy = _clip((1 - h * xx) / k)

    ax.plot_surface(xx, yy, zz, color=color, alpha=alpha)

    # Legend proxy: an invisible (empty) line carrying the plane's label.
    ax.plot([], [], [], color=color, label=label, alpha=alpha)


# Initialize the 3D plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Define plotting limits for better visualization of planes and vectors.
# NOTE(review): plot_plane samples/clips its surfaces to the same +/-1.5 box;
# keep the two values in sync if either changes.
plot_limit = 1.5
ax.set_xlim([-plot_limit, plot_limit])
ax.set_ylim([-plot_limit, plot_limit])
ax.set_zlim([-plot_limit, plot_limit])

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('Cubic Crystal Planes with [001] Zone Axis')

# Plot the [001] Zone Axis (arrow of length 1.2, label just beyond its tip)
zone_axis = np.array([0, 0, 1])
ax.quiver(0, 0, 0, zone_axis[0], zone_axis[1], zone_axis[2],
          color='black', length=1.2, arrow_length_ratio=0.1)
ax.text(zone_axis[0] * 1.3, zone_axis[1] * 1.3, zone_axis[2] * 1.3, r'$[001]$ Zone Axis', color='black')

# Define planes (hkl), their colors, and labels.
# Note: the plane written as "(1-1)" in the original plane list is
# interpreted as (1-10). That choice places it in the [001] zone, because it
# satisfies the zone law h*u + k*v + l*w = 0: 1*0 + (-1)*0 + 0*1 = 0.
planes_data = [
    {'hkl': (1, 1, 0), 'color': 'red', 'label': r'$(110)$'},
    {'hkl': (1, 2, 0), 'color': 'blue', 'label': r'$(120)$'},
    {'hkl': (0, 1, 0), 'color': 'green', 'label': r'$(010)$'},
    {'hkl': (1, -2, 0), 'color': 'purple', 'label': r'$(1\bar{2}0)$'},
    {'hkl': (1, -1, 0), 'color': 'orange', 'label': r'$(1\bar{1}0)$'} # Interpreted as (1-10)
]

# Drawn length of every plane-normal arrow.
normal_arrow_length = 0.8

for plane_info in planes_data:
    h, k, l = plane_info['hkl']
    color = plane_info['color']
    label = plane_info['label']

    plot_plane(ax, h, k, l, color, alpha=alpha, label=label)

    # Plot plane normal vector for cubic system. For cubic, [hkl] is normal to (hkl)
    normal_vector = np.array([h, k, l])
    normal_norm = np.linalg.norm(normal_vector)
    if normal_norm > 0:  # [000] has no direction; also avoids division by zero
        unit_normal = normal_vector / normal_norm
        # Axes3D.quiver multiplies (U, V, W) by `length` (normalize=False), so
        # pass the UNIT normal; passing an already-scaled vector would shrink
        # the arrow to 0.8 * 0.8 and leave the label detached from the tip.
        normal_vector_scaled = unit_normal * normal_arrow_length
        ax.quiver(0, 0, 0, unit_normal[0], unit_normal[1], unit_normal[2],
                  color=color, linestyle='--', length=normal_arrow_length, arrow_length_ratio=0.2)
        # Add label for the normal vector, just past the arrow tip
        ax.text(normal_vector_scaled[0] * 1.1, normal_vector_scaled[1] * 1.1, normal_vector_scaled[2] * 1.1,
                label + ' normal', color=color)

ax.legend()
plt.tight_layout()
plt.savefig('cubic_planes_001_zone.png')
plt.show() # To display the plot window as well