# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional

from ....extras.types import HFModel
from ...trainer_plugins.distributed.accelerate import get_available_accelerator
from .constants import DeviceType, KernelType


logger = logging.getLogger(__name__)


class KernelRegistry:
    """Process-wide singleton mapping (kernel type, device type) pairs to kernel implementations."""

    _instance: Optional['KernelRegistry'] = None
    _initialized: bool = False

    def __new__(cls, *args: Any, **kwargs: Any) -> 'KernelRegistry':
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

    def __init__(self) -> None:
if self._initialized:
return
self._registry: dict[KernelType, dict[DeviceType, Callable[..., Any]]] = {}
self._initialized = True

    def register(
        self,
        kernel_type: KernelType,
        device_type: DeviceType,
        kernel_impl: Callable[..., Any],
    ) -> None:
"""Register a kernel implementation.
Args:
kernel_type: the type of the kernel (e.g., KernelType.FLASH_ATTENTION).
device_type: the device type the kernel is adapted to (e.g., DeviceType.CUDA).
kernel_impl: the actual kernel function or class.
"""
        if kernel_type not in self._registry:
            self._registry[kernel_type] = {}

        if device_type in self._registry[kernel_type]:
            logger.warning(f"Overwriting kernel for {kernel_type.name} on {device_type.name}.")

        self._registry[kernel_type][device_type] = kernel_impl
        logger.info(f"Registered kernel {kernel_type.name} for device {device_type.name}.")

    def get_kernel(
        self,
        kernel_type: KernelType,
        device_type: DeviceType,
    ) -> Optional[Callable[..., Any]]:
        """Return the kernel registered for the given type and device, or ``None`` if absent."""
        return self._registry.get(kernel_type, {}).get(device_type)


KERNEL_REGISTRY = KernelRegistry()
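
# Illustrative registry usage (a minimal sketch; ``fused_attention`` is a
# hypothetical function, and the enum members follow the examples in the
# ``register`` docstring above):
#
#     def fused_attention(*args, **kwargs): ...
#
#     KERNEL_REGISTRY.register(KernelType.FLASH_ATTENTION, DeviceType.CUDA, fused_attention)
#     impl = KERNEL_REGISTRY.get_kernel(KernelType.FLASH_ATTENTION, DeviceType.CUDA)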


class MetaKernel(ABC):
    """Abstract base for model-patching kernels; concrete subclasses bind a kernel type, device, and implementation."""

    type: Optional[KernelType] = None
    device: Optional[DeviceType] = None
    kernel: Optional[Callable[..., Any]] = None

    @classmethod
    def register_kernel(cls, kernel_type: KernelType, device_type: DeviceType) -> None:
        """Register this kernel class in the global registry for the given kernel and device type."""
        KERNEL_REGISTRY.register(kernel_type, device_type, cls)

    @classmethod
    @abstractmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        """Apply this kernel to ``model`` and return the patched model."""
        raise NotImplementedError


class MetaFlashAttentionKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaRMSNormKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaSwiGluKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaRoPEKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError


class MetaMoEKernel(MetaKernel):
    @classmethod
    def apply(cls, model: HFModel, **kwargs) -> HFModel:
        raise NotImplementedError
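
# A concrete kernel subclasses one of the marker classes above, binds a device,
# implements ``apply``, and registers itself. Hypothetical sketch (the class
# name and the ``DeviceType.NPU`` member are illustrative assumptions, not part
# of this module):
#
#     class NpuRMSNormKernel(MetaRMSNormKernel):
#         device = DeviceType.NPU  # assumes a DeviceType.NPU member exists
#
#         @classmethod
#         def apply(cls, model: HFModel, **kwargs) -> HFModel:
#             # replace the model's RMSNorm modules with a fused NPU op here
#             return model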


def discover_kernels(model: HFModel) -> list[MetaKernel]:
    """Discover and construct MetaKernel instances for the current model/device.

    This is a placeholder to be implemented: it should inspect the runtime
    environment (device type, available extensions, model architecture) and
    return an ordered list of MetaKernel instances to be applied. Each returned
    MetaKernel must encapsulate its own replacement logic in ``apply``.
    """
    # TODO: Implement auto-discovery logic based on registry and device capabilities.
    return []
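
# Once discovery is implemented, the intended call pattern might look like this
# (a sketch of expected usage, not current behavior):
#
#     for kernel in discover_kernels(model):
#         model = kernel.apply(model)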


def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> HFModel:
    """Call the MetaKernel's ``apply`` to perform the replacement.

    The corresponding replacement logic is maintained inside each kernel; the only
    requirement is that ``apply`` returns the replaced model.

    Example:
        from transformers import AutoModelForCausalLM

        from .rms_norm.npu_rms_norm import NpuRMSNormKernel

        model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
        model = apply_kernel(model, NpuRMSNormKernel)
    """
    if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
        return kernel.apply(model, **kwargs)

    raise ValueError(
        f"{kernel} must be a MetaKernel subclass whose device matches the current accelerator; "
        f"got device {kernel.device} while the available accelerator is {get_available_accelerator().type}."
    )