|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional

from ....extras.types import HFModel
from ...trainer_plugins.distributed.accelerate import get_available_accelerator
from .constants import DeviceType, KernelType
|
|
|
|
|
|
|
|
class KernelRegistry:
    """Singleton lookup table mapping (kernel type, device type) to an implementation.

    Every ``KernelRegistry()`` call returns the same object (see ``__new__``), and
    ``__init__`` only builds the table on the first call. Registrations made
    through any handle are therefore visible through all of them.
    """

    _instance: Optional['KernelRegistry'] = None
    _initialized: bool = False

    def __new__(cls, *args: Any, **kwargs: Any) -> 'KernelRegistry':
        # Lazily create the single shared instance on first construction and
        # hand the same object back on every later construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # Python re-runs __init__ on every KernelRegistry() call, even though
        # __new__ returns the existing object; skip re-initialisation so existing
        # registrations are not wiped.
        if self._initialized:
            return
        self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
        self._initialized = True

    def register(
        self,
        kernel_type: KernelType,
        device_type: DeviceType,
        kernel_impl: Optional[Callable[..., Any]]
    ) -> None:
        """Register a kernel implementation.

        An existing entry for the same (kernel_type, device_type) pair is
        replaced, and a warning is logged when that happens.

        Args:
            kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
            device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
            kernel_impl: the actual kernel function or class.
        """
        # Use logging rather than print so library diagnostics can be filtered,
        # redirected or silenced by the application instead of always hitting stdout.
        logger = logging.getLogger(__name__)
        per_device = self._registry.setdefault(kernel_type, {})

        if device_type in per_device:
            logger.warning("Overwriting kernel for %s on %s.", kernel_type.name, device_type.name)

        per_device[device_type] = kernel_impl
        logger.debug("Registered kernel %s for device %s.", kernel_type.name, device_type.name)

    def get_kernel(
        self,
        kernel_type: KernelType,
        device_type: DeviceType
    ) -> Optional[Callable[..., Any]]:
        """Return the implementation registered for the pair, or None if there is none."""
        return self._registry.get(kernel_type, {}).get(device_type)


# Shared registry handle used by MetaKernel.register_kernel.
KERNEL_REGISTRY = KernelRegistry()
|
|
|
|
|
|
|
|
class MetaKernel(ABC):
    """Base class for a pluggable kernel that rewrites parts of an HF model.

    A subclass describes itself through the class attributes below and keeps
    its module-replacement logic in the ``apply`` classmethod.
    """

    # KernelType this implementation provides (unset on the base class).
    type: Optional[KernelType] = None
    # DeviceType this implementation targets (unset on the base class).
    # apply_kernel compares this against the available accelerator's type.
    device: Optional[DeviceType] = None
    # Optional handle to the underlying kernel callable (unset on the base class).
    kernel: Optional[Callable] = None

    @classmethod
    def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType):
        """Store this class in KERNEL_REGISTRY under (kernel_type, device_type)."""
        KERNEL_REGISTRY.register(
            kernel_type=kernel_type,
            device_type=device_type,
            kernel_impl=cls,
        )

    @classmethod
    @abstractmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        """Perform this kernel's replacement on ``model`` and return the resulting model.

        Abstract: subclasses must override it; the base version always raises.
        """
        raise NotImplementedError()
|
|
|
|
|
|
|
|
class MetaFlashAttentionKernel(MetaKernel):
    """Family base class for flash-attention kernel implementations.

    NOTE(review): this concrete ``apply`` override clears the abstract flag
    inherited from MetaKernel. Subclasses that forget to override ``apply``
    therefore only fail when ``apply`` is called.
    """

    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        """Placeholder; concrete flash-attention kernels provide the real replacement."""
        raise NotImplementedError()
|
|
|
|
|
|
|
|
class MetaRMSNormKernel(MetaKernel):
    """Family base class for RMSNorm kernel implementations.

    NOTE(review): this concrete ``apply`` override clears the abstract flag
    inherited from MetaKernel. Subclasses that forget to override ``apply``
    therefore only fail when ``apply`` is called.
    """

    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        """Placeholder; concrete RMSNorm kernels provide the real replacement."""
        raise NotImplementedError()
|
|
|
|
|
|
|
|
class MetaSwiGluKernel(MetaKernel):
    """Family base class for SwiGLU kernel implementations.

    NOTE(review): this concrete ``apply`` override clears the abstract flag
    inherited from MetaKernel. Subclasses that forget to override ``apply``
    therefore only fail when ``apply`` is called.
    """

    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        """Placeholder; concrete SwiGLU kernels provide the real replacement."""
        raise NotImplementedError()
|
|
|
|
|
|
|
|
class MetaRoPEKernel(MetaKernel):
    """Family base class for rotary-position-embedding (RoPE) kernel implementations.

    NOTE(review): this concrete ``apply`` override clears the abstract flag
    inherited from MetaKernel. Subclasses that forget to override ``apply``
    therefore only fail when ``apply`` is called.
    """

    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        """Placeholder; concrete RoPE kernels provide the real replacement."""
        raise NotImplementedError()
|
|
|
|
|
|
|
|
class MetaMoEKernel(MetaKernel):
    """Family base class for mixture-of-experts (MoE) kernel implementations.

    NOTE(review): this concrete ``apply`` override clears the abstract flag
    inherited from MetaKernel. Subclasses that forget to override ``apply``
    therefore only fail when ``apply`` is called.
    """

    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        """Placeholder; concrete MoE kernels provide the real replacement."""
        raise NotImplementedError()
|
|
|
|
|
|
|
|
def discover_kernels(model: HFModel) -> list[MetaKernel]:
    """Return the ordered MetaKernel instances to apply to ``model`` on this device.

    Not implemented yet. The intended behaviour is to look at the runtime
    environment (device type, available extensions, model architecture) and
    pick the matching kernels, each of which carries its own replacement
    logic in ``apply``.
    """
    # Placeholder: nothing is discovered yet, so every call yields a fresh empty list.
    discovered: list[MetaKernel] = []
    return discovered
|
|
|
|
|
|
|
|
def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> 'HFModel':
    """Call the MetaKernel's `apply` to perform the replacement.

    Corresponding replacement logic is maintained inside each kernel; the only
    requirement is that `apply` returns the replaced model.

    Args:
        model: the model to patch.
        kernel: a MetaKernel subclass. An instance of one is also accepted,
            since `device` and the `apply` classmethod resolve the same way on it.
        **kwargs: forwarded unchanged to `kernel.apply`.

    Returns:
        Whatever `kernel.apply` returns (the replaced model).

    Raises:
        ValueError: if `kernel` is neither a MetaKernel subclass nor an instance
            of one, or if `kernel.device` does not equal the type of the
            currently available accelerator.

    Example:
        from transformers import AutoModelForCausalLM
        from .rms_norm.npu_rms_norm import NpuRMSNormKernel

        model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
        model = apply_kernel(model, NpuRMSNormKernel)
    """
    # Validate the argument before touching any attribute. `issubclass` raises
    # TypeError on non-class input, and an arbitrary class may have no `device`.
    is_kernel_class = isinstance(kernel, type) and issubclass(kernel, MetaKernel)
    if not (is_kernel_class or isinstance(kernel, MetaKernel)):
        raise ValueError(f"{kernel!r} must be a MetaKernel subclass or instance.")

    # Query the accelerator once so the check and the error message agree.
    accelerator_type = get_available_accelerator().type
    if not kernel.device == accelerator_type:
        raise ValueError(
            f"{kernel!r} does not match the device type: "
            f"got {kernel.device} and {accelerator_type} instead."
        )

    return kernel.apply(model, **kwargs)
|
|
|