Source code for analytic_continuation_server.routes.domaincolor

"""
Domain coloring API endpoints.

Provides routes for:
- Domain coloring image generation
- Expression validation
- Validation renders with overlays
"""

import io
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from ..models import DomainColorRequest, ValidationRenderRequest

router = APIRouter(prefix="/api/domaincolor", tags=["domaincolor"])


[docs] @router.post("") async def generate_domain_color(request: DomainColorRequest): """ Generate a domain coloring image. Returns a PNG image as a streaming response. """ try: # Import py_domaincolor (optional dependency) from py_domaincolor import domain_color_array, get_callable import numpy as np from PIL import Image y_range = request.y_range or request.x_range # Get the function f = get_callable(request.expression) # Generate the domain coloring array img_array = domain_color_array( f, x_range=request.x_range, y_range=y_range, resolution=request.resolution, mode=request.mode, mod_contours=request.mod_contours, arg_contours=request.arg_contours, ) # Convert to PIL Image img = Image.fromarray((img_array * 255).astype(np.uint8)) # Save to bytes buffer buffer = io.BytesIO() img.save(buffer, format="PNG") buffer.seek(0) return StreamingResponse(buffer, media_type="image/png") except ImportError: raise HTTPException( status_code=501, detail="py_domaincolor package not installed. Install with: pip install py-domaincolor" ) except Exception as e: raise HTTPException(status_code=400, detail=str(e))
[docs] @router.get("/validate") async def validate_expression(expression: str): """Validate a complex function expression without generating an image.""" try: from py_domaincolor import validate_expression is_valid, error = validate_expression(expression) return {"valid": is_valid, "error": error} except ImportError: raise HTTPException( status_code=501, detail="py_domaincolor package not installed" )
[docs] @router.post("/validation-render") async def validation_render(request: ValidationRenderRequest): """ Generate a Python-rendered domain coloring with zeros/poles overlay. This serves as ground truth for validating the WebGL renderer. Returns a PNG image with: - Domain coloring of the meromorphic function - Zero markers (green circles) - Pole markers (red X's) - Optional curve overlay - Axis tick marks on border """ try: import numpy as np from PIL import Image, ImageDraw from py_domaincolor import get_callable, complex_to_rgb, create_domain_grid y_range = request.y_range or request.x_range # Build the meromorphic function from zeros and poles if request.zeros or request.poles: numerator_terms = [] for z in request.zeros: for _ in range(z.multiplicity): if z.y >= 0: numerator_terms.append(f"(z - ({z.x}+{z.y}*I))") else: numerator_terms.append(f"(z - ({z.x}{z.y}*I))") denominator_terms = [] for p in request.poles: for _ in range(p.multiplicity): if p.y >= 0: denominator_terms.append(f"(z - ({p.x}+{p.y}*I))") else: denominator_terms.append(f"(z - ({p.x}{p.y}*I))") numerator = "*".join(numerator_terms) if numerator_terms else "1" denominator = "*".join(denominator_terms) if denominator_terms else "1" if request.expression and request.expression.strip() != "1": full_expr = f"({request.expression}) * {numerator} / ({denominator})" else: full_expr = f"{numerator} / ({denominator})" else: full_expr = request.expression # Get the callable function f = get_callable(full_expr) # Create the grid and evaluate X, Y, Z = create_domain_grid(request.x_range, y_range, request.resolution) with np.errstate(all='ignore'): W = f(Z) # Generate domain coloring rgb = complex_to_rgb( W, mode='standard', mod_contours=request.mod_contours, arg_contours_flag=request.arg_contours ) # Convert to PIL Image (flip Y for standard orientation) img_array = (rgb * 255).astype(np.uint8) img_array = np.flipud(img_array) img = Image.fromarray(img_array) # Draw overlays draw = ImageDraw.Draw(img) # Helper to convert complex coords to pixel coords def complex_to_pixel(re, im): px = int((re - request.x_range[0]) / (request.x_range[1] - request.x_range[0]) * request.resolution) py = int((y_range[1] - im) / (y_range[1] - y_range[0]) * request.resolution) return px, py # Draw curve if provided if request.curve_points and len(request.curve_points) >= 2: curve_pixels = [complex_to_pixel(p['re'], p['im']) for p in request.curve_points] for i in range(len(curve_pixels) - 1): draw.line( [curve_pixels[i], curve_pixels[i + 1]], fill=request.curve_color, width=int(request.curve_width) ) # Draw zeros (green circles) for z in request.zeros: px, py = complex_to_pixel(z.x, z.y) r = 8 + 2 * (z.multiplicity - 1) draw.ellipse([px - r, py - r, px + r, py + r], outline='#00ff88', width=2) if z.multiplicity > 1: draw.text((px + r + 2, py - 6), str(z.multiplicity), fill='#00ff88') # Draw poles (red X's) for p in request.poles: px, py = complex_to_pixel(p.x, p.y) r = 8 + 2 * (p.multiplicity - 1) draw.line([px - r, py - r, px + r, py + r], fill='#ff4444', width=2) draw.line([px + r, py - r, px - r, py + r], fill='#ff4444', width=2) if p.multiplicity > 1: draw.text((px + r + 2, py - 6), str(p.multiplicity), fill='#ff4444') # Draw axis tick marks on border if request.show_axis_ticks: tick_len = 10 # Real axis ticks x_start = np.ceil(request.x_range[0] / request.tick_interval) * request.tick_interval x_end = np.floor(request.x_range[1] / request.tick_interval) * request.tick_interval for x_val in np.arange(x_start, x_end + 0.001, request.tick_interval): px, _ = complex_to_pixel(x_val, 0) draw.line([px, request.resolution - 1, px, request.resolution - tick_len], fill='white', width=1) draw.line([px, 0, px, tick_len], fill='white', width=1) label = f"{x_val:.0f}" if x_val == int(x_val) else f"{x_val:.1f}" draw.text((px - 8, request.resolution - tick_len - 12), label, fill='white') # Imaginary axis ticks y_start = np.ceil(y_range[0] / request.tick_interval) * request.tick_interval y_end = np.floor(y_range[1] / request.tick_interval) * request.tick_interval for y_val in np.arange(y_start, y_end + 0.001, request.tick_interval): _, py = complex_to_pixel(0, y_val) draw.line([0, py, tick_len, py], fill='white', width=1) draw.line([request.resolution - 1, py, request.resolution - tick_len, py], fill='white', width=1) label = f"{y_val:.0f}i" if y_val == int(y_val) else f"{y_val:.1f}i" draw.text((tick_len + 2, py - 6), label, fill='white') # Save to bytes buffer buffer = io.BytesIO() img.save(buffer, format="PNG") buffer.seek(0) return StreamingResponse(buffer, media_type="image/png") except ImportError as e: raise HTTPException(status_code=501, detail=f"Required package not installed: {e}") except Exception as e: import traceback raise HTTPException(status_code=400, detail=f"{str(e)}\n{traceback.format_exc()}")