from ev import Job, Env, Volume
from typing import Dict
import logging
import daft
import numpy as np
import torch
from daft import col
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from datetime import datetime
# Configure environment with ML dependencies
env = Env().pip_install([
    "torch==2.0.0",
    "torchvision==0.15.0",
    "pylance==0.8.0",  # Python bindings for the Lance format used by write_lance below
])
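# daft and numpy are assumed to be preinstalled in the job runtime; add them to
# the pip_install list above if the environment does not already provide them.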
# Define data volume for images
image_volume = Volume(
name="product-images",
path="s3://company-images/products/"
)
# Define ResNet model UDF for distributed inference
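# (daft.udf also accepts resource-request arguments such as num_gpus=1, which
# would reserve a GPU per UDF instance when running on a GPU cluster.)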
@daft.udf(
return_dtype=daft.DataType.list(dtype=daft.DataType.float32()),
)
class ResNetModel:
def __init__(self):
weights = ResNet50_Weights.DEFAULT
self.model = resnet50(weights=weights)
self.model.eval()
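        # The full classifier head is kept, so the "features" produced below are the
        # 1000-dim class logits; swapping in `self.model.fc = torch.nn.Identity()`
        # would yield 2048-dim penultimate-layer embeddings instead.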
# Use GPU if available
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)
logging.info(f"ResNet model initialized on {self.device}")
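    # Called with a Daft Series of preprocessed image tensors; returns one
    # feature vector per input image.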
def __call__(self, images):
if len(images) == 0:
return []
with torch.inference_mode():
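            # Stack the per-image (3, 224, 224) arrays into a single batch on the target device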
tensor = torch.as_tensor(
np.array(images.to_pylist()),
device=self.device
)
result = self.model(tensor)
cpu_result = result.cpu()
numpy_result = cpu_result.numpy()
return numpy_result.tolist()
# Create the job
job = Job()
@job.main(env=env)
def process_images(
    input_prefix: str = "raw/",
    output_prefix: str = "processed/",
    max_images: int = 10000,
    batch_size: int = 16,
    concurrency: int = 4,
    data_volume: Volume = image_volume
) -> Dict:
    """
    Process images with feature extraction using Daft and Ray.
    Args:
        input_prefix: S3 prefix for input images
        output_prefix: S3 prefix for processed results
        max_images: Maximum number of images to process (0 or negative means no limit)
        batch_size: Nominal batch size, recorded in the output metadata
        concurrency: Nominal worker concurrency, recorded in the output statistics
        data_volume: Volume containing image data
    Returns:
        Dictionary with processing statistics
    """
    logging.info("Starting image processing job with Daft")
logging.info(f"Input: {data_volume.path}{input_prefix}")
logging.info(f"Output: {data_volume.path}{output_prefix}")
# Load image metadata
metadata_path = data_volume.path + input_prefix + "metadata.parquet"
df = daft.read_parquet(metadata_path)
# Limit processing if specified
if max_images > 0:
df = df.limit(max_images)
# Create image URLs from metadata
# Assumes metadata has 'folder' and 'filename' columns
    df = df.with_column(
        "image_url",
        daft.lit(data_volume.path + input_prefix + "images/")
        + df["folder"] + "/" + df["filename"] + ".jpeg",
    )
# Download and decode images
logging.info("Downloading and decoding images...")
df = df.with_column(
"image",
df["image_url"]
.url.download(on_error="null")
.image.decode(on_error="null", mode=daft.ImageMode.RGB),
)
    # Drop rows where the download or decode failed (those produced nulls above)
    df = df.drop_null("image")
    # Define image preprocessing transforms (resize/crop plus the ImageNet mean/std that ResNet-50 expects)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Apply transforms
logging.info("Applying image transforms...")
df = df.with_column(
"norm_image",
df["image"].apply(
            func=lambda img: transform(img).numpy(),  # back to NumPy so Daft can store the fixed-shape tensor
return_dtype=daft.DataType.tensor(
dtype=daft.DataType.float32(), shape=(3, 224, 224)
),
),
)
# Extract features using ResNet model
logging.info("Extracting features with ResNet model...")
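    # ResNetModel is a class UDF, so each worker initializes the model once and
    # reuses it across batches instead of reloading it per call.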
df = df.with_column(
"features",
ResNetModel(col("norm_image"))
)
# Add processing metadata
df = df.with_columns({
"processed_at": daft.lit(datetime.now().isoformat()),
"batch_size": daft.lit(batch_size),
})
# Clean up columns for output
df_final = df.select([
"image_url", "folder", "filename", "features", "processed_at", "batch_size"
])
    # Materialize the results once so the row count below does not re-run inference
    df_final = df_final.collect()
    # Save results to Lance format
    output_path = data_volume.path + output_prefix + "features.lance"
    logging.info("Saving results to Lance format...")
    df_final.write_lance(output_path, mode="overwrite")
    # Calculate statistics
    total_images = len(df_final)
stats = {
"total_images": total_images,
"output_path": output_path,
"batch_size": batch_size,
"concurrency": concurrency,
"processing_time": "computed_by_platform"
}
logging.info(f"Processing complete: {stats}")
return stats
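# A minimal sketch of reading the features back for downstream use (assumes the
# job has completed and that `daft.read_lance` is available in the installed daft
# version; the path mirrors the output_path constructed above):
#
#   feature_df = daft.read_lance("s3://company-images/products/processed/features.lance")
#   feature_df.select("image_url", "features").show()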