{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "5d81bb13", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "dataset = load_dataset(\"facebook/natural_reasoning\")\n", "train_data = dataset[\"train\"].select(range(5000)) # Start with 5k examples\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "5279c3c3", "metadata": {}, "outputs": [], "source": [ "def format_for_training(example):\n", " return {\n", " \"prompt\": example[\"question\"],\n", " \"completion\": example[\"reference_answer\"]\n", " }\n", "\n", "train_data = train_data.map(format_for_training)\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "d5f715b3", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "model_checkpoint = \"distilgpt2\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n", "tokenizer.pad_token = tokenizer.eos_token\n", "max_seq_length = 512\n", "\n", "def tokenize(example):\n", " input_text = f\"### Question: {example['prompt']}\\n### Answer: {example['completion']}{tokenizer.eos_token}\"\n", " tokenized = tokenizer(\n", " input_text,\n", " padding=\"max_length\",\n", " truncation=True,\n", " max_length=max_seq_length\n", " )\n", " tokenized[\"labels\"] = tokenized[\"input_ids\"].copy()\n", " return tokenized\n", "\n", "tokenized_data = train_data.map(tokenize)\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "61cb619d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\shukl\\anaconda3\\Lib\\site-packages\\transformers\\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n", "C:\\Users\\shukl\\AppData\\Local\\Temp\\ipykernel_7600\\3538093026.py:16: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", " trainer = Trainer(\n", "`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.\n" ] }, { "data": { "text/html": [ "\n", "
Running the `Trainer` cell emits a few deprecation warnings before training begins:

```text
FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
`loss_type=None` was set in the config but it is unrecognised. Using the default loss: `ForCausalLMLoss`.
```

Training loss, logged every 500 steps:

| Step | Training Loss |
|-----:|--------------:|
|  500 | 0.836400 |
| 1000 | 0.629200 |
| 1500 | 0.631400 |
| 2000 | 0.622300 |
| 2500 | 0.631600 |
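With the loss flattening out around 0.62-0.63, a quick qualitative check is to generate from the fine-tuned model with the same prompt template it was trained on. A minimal sketch, assuming the `trainer` object created in the training cell and using a hypothetical test question:

```python
import torch

model = trainer.model.eval()  # the fine-tuned distilgpt2 held by the Trainer
prompt = "### Question: What is the derivative of x^2?\n### Answer:"  # hypothetical question
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=False,                      # greedy decoding for a repeatable check
        pad_token_id=tokenizer.eos_token_id,  # explicit pad id avoids a generate() warning
    )

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Because generation stops at the EOS token, a model that has learned the template should terminate its answers on its own.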
"
],
"text/plain": [
"