Speeding up PyTorch inference by 87% on Apple devices with AI-generated Metal kernels

https://news.ycombinator.com/rss Hits: 19
Summary

1import torch 2import torch.nn as nn 3import torch.nn.functional as F 4from einops import rearrange 5 6# Safe wrappers: try to build the Metal extension, but always provide PyTorch fallbacks. 7mpskern = None 8_have_native_kernels = False 9try: 10 from torch.utils.cpp_extension import load_inline 11 12 cpp_source = r''' 13 #include <torch/extension.h> 14 #import <Foundation/Foundation.h> 15 #import <Metal/Metal.h> 16 17 static const char *METAL_SRC = R"KERNEL( 18 #include <metal_stdlib> 19 using namespace metal; 20 21 // Compute exp(segsum) lower triangular matrix from cumsum prefix for 4D case: 22 // prefix shape: [num_vec, L] 23 // output shape: [num_vec, L, L] 24 // value(i,j) = j <= i ? exp(prefix[i] - prefix[j]) : 0 25 kernel void lower_tri_from_prefix_4d(constant float* prefix [[buffer(0)]], 26 device float* out [[buffer(1)]], 27 constant uint* params [[buffer(2)]], 28 uint index [[thread_position_in_grid]]) { 29 uint num_vec = params[0]; 30 uint L = params[1]; 31 uint total = num_vec * L * L; 32 if (index >= total) return; 33 34 uint vecId = index / (L * L); 35 uint rem = index - vecId * (L * L); 36 uint i = rem / L; 37 uint j = rem - i * L; 38 39 if (j <= i) { 40 float vi = prefix[vecId * L + i]; 41 float vj = prefix[vecId * L + j]; 42 out[vecId * (L * L) + i * L + j] = exp(vi - vj); 43 } else { 44 out[vecId * (L * L) + i * L + j] = 0.0f; 45 } 46 } 47 48 // Same as above for 3D prefix: prefix shape [num_vec, Z], output [num_vec, Z, Z] 49 kernel void lower_tri_from_prefix_3d(constant float* prefix [[buffer(0)]], 50 device float* out [[buffer(1)]], 51 constant uint* params [[buffer(2)]], 52 uint index [[thread_position_in_grid]]) { 53 uint num_vec = params[0]; 54 uint Z = params[1]; 55 uint total = num_vec * Z * Z; 56 if (index >= total) return; 57 58 uint vecId = index / (Z * Z); 59 uint rem = index - vecId * (Z * Z); 60 uint i = rem / Z; 61 uint j = rem - i * Z; 62 63 if (j <= i) { 64 float vi = prefix[vecId * Z + i]; 65 float vj = prefix[vecId * Z + j]; 66 
o...

First seen: 2025-09-03 17:56

Last seen: 2025-09-04 12:00