Commit
·
472e2e9
0
Parent(s):
Sure! Pl
Browse files- api_service.py +614 -0
- data/.gitkeep +2 -0
- evaluation/.gitkeep +2 -0
- models/.gitkeep +2 -0
- raw_dataset.json +0 -0
- requirements.txt +12 -0
- setup-guide.md +342 -0
- test_api.py +160 -0
- training_pipeline.py +772 -0
api_service.py
ADDED
@@ -0,0 +1,614 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Backend Code Generation API Service
|
4 |
+
===================================
|
5 |
+
|
6 |
+
Production-ready API service for serving the trained backend code generation model.
|
7 |
+
Provides RESTful endpoints for generating complete backend applications.
|
8 |
+
"""
|
9 |
+
|
10 |
+
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
|
11 |
+
from fastapi.middleware.cors import CORSMiddleware
|
12 |
+
from fastapi.responses import StreamingResponse, FileResponse
|
13 |
+
from pydantic import BaseModel, Field
|
14 |
+
from typing import List, Dict, Optional, Any
|
15 |
+
import torch
|
16 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
17 |
+
import json
|
18 |
+
import zipfile
|
19 |
+
import tempfile
|
20 |
+
import os
|
21 |
+
import uuid
|
22 |
+
from datetime import datetime
|
23 |
+
import asyncio
|
24 |
+
import logging
|
25 |
+
from pathlib import Path
|
26 |
+
|
27 |
+
# Configure logging
|
28 |
+
logging.basicConfig(level=logging.INFO)
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
# Pydantic models for API
|
32 |
+
class CodeGenerationRequest(BaseModel):
|
33 |
+
description: str = Field(..., description="Description of the backend application to generate")
|
34 |
+
framework: str = Field(..., description="Target framework (express, fastapi, django, flask)")
|
35 |
+
language: str = Field(..., description="Programming language (javascript, python)")
|
36 |
+
requirements: List[str] = Field(default=[], description="List of specific requirements")
|
37 |
+
project_name: Optional[str] = Field(default=None, description="Custom project name")
|
38 |
+
|
39 |
+
class Config:
|
40 |
+
schema_extra = {
|
41 |
+
"example": {
|
42 |
+
"description": "E-commerce API with user authentication and product management",
|
43 |
+
"framework": "fastapi",
|
44 |
+
"language": "python",
|
45 |
+
"requirements": [
|
46 |
+
"User registration and login",
|
47 |
+
"JWT authentication",
|
48 |
+
"Product CRUD operations",
|
49 |
+
"Shopping cart functionality",
|
50 |
+
"Order management"
|
51 |
+
],
|
52 |
+
"project_name": "ecommerce-api"
|
53 |
+
}
|
54 |
+
}
|
55 |
+
|
56 |
+
class GenerationResponse(BaseModel):
|
57 |
+
task_id: str
|
58 |
+
status: str
|
59 |
+
message: str
|
60 |
+
estimated_time: int
|
61 |
+
|
62 |
+
class GenerationStatus(BaseModel):
|
63 |
+
task_id: str
|
64 |
+
status: str # pending, processing, completed, failed
|
65 |
+
progress: int # 0-100
|
66 |
+
message: str
|
67 |
+
generated_files: Optional[Dict[str, str]] = None
|
68 |
+
download_url: Optional[str] = None
|
69 |
+
error: Optional[str] = None
|
70 |
+
|
71 |
+
class GeneratedProject(BaseModel):
|
72 |
+
project_name: str
|
73 |
+
framework: str
|
74 |
+
language: str
|
75 |
+
files: Dict[str, str]
|
76 |
+
structure: Dict[str, Any]
|
77 |
+
setup_instructions: List[str]
|
78 |
+
features: List[str]
|
79 |
+
|
80 |
+
# Global model instance
|
81 |
+
class ModelManager:
|
82 |
+
def __init__(self):
|
83 |
+
self.model = None
|
84 |
+
self.tokenizer = None
|
85 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
86 |
+
self.loaded = False
|
87 |
+
|
88 |
+
async def load_model(self, model_path: str = "./trained_model"):
|
89 |
+
"""Load the trained model asynchronously"""
|
90 |
+
try:
|
91 |
+
logger.info(f"Loading model from {model_path} on {self.device}")
|
92 |
+
|
93 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
94 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
95 |
+
model_path,
|
96 |
+
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
|
97 |
+
device_map="auto" if self.device == "cuda" else None
|
98 |
+
)
|
99 |
+
|
100 |
+
if self.device == "cpu":
|
101 |
+
self.model = self.model.to(self.device)
|
102 |
+
|
103 |
+
self.loaded = True
|
104 |
+
logger.info("Model loaded successfully!")
|
105 |
+
|
106 |
+
except Exception as e:
|
107 |
+
logger.error(f"Failed to load model: {e}")
|
108 |
+
raise
|
109 |
+
|
110 |
+
def generate_code(self, prompt: str, max_tokens: int = 1024) -> str:
|
111 |
+
"""Generate code using the trained model"""
|
112 |
+
if not self.loaded:
|
113 |
+
raise RuntimeError("Model not loaded")
|
114 |
+
|
115 |
+
inputs = self.tokenizer.encode(prompt, return_tensors='pt')
|
116 |
+
inputs = inputs.to(self.device)
|
117 |
+
|
118 |
+
with torch.no_grad():
|
119 |
+
outputs = self.model.generate(
|
120 |
+
inputs,
|
121 |
+
max_length=min(max_tokens, 1024),
|
122 |
+
num_return_sequences=1,
|
123 |
+
temperature=0.7,
|
124 |
+
do_sample=True,
|
125 |
+
top_p=0.9,
|
126 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
127 |
+
repetition_penalty=1.1
|
128 |
+
)
|
129 |
+
|
130 |
+
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
131 |
+
return generated_text[len(self.tokenizer.decode(inputs[0], skip_special_tokens=True)):]
|
132 |
+
|
133 |
+
# Global instances
|
134 |
+
model_manager = ModelManager()
|
135 |
+
generation_tasks = {} # Store generation tasks
|
136 |
+
|
137 |
+
# FastAPI app
|
138 |
+
app = FastAPI(
|
139 |
+
title="Backend Code Generation API",
|
140 |
+
description="AI-powered backend application generator",
|
141 |
+
version="1.0.0",
|
142 |
+
docs_url="/docs",
|
143 |
+
redoc_url="/redoc"
|
144 |
+
)
|
145 |
+
|
146 |
+
# CORS middleware
|
147 |
+
app.add_middleware(
|
148 |
+
CORSMiddleware,
|
149 |
+
allow_origins=["*"], # Configure for production
|
150 |
+
allow_credentials=True,
|
151 |
+
allow_methods=["*"],
|
152 |
+
allow_headers=["*"],
|
153 |
+
)
|
154 |
+
|
155 |
+
@app.on_event("startup")
|
156 |
+
async def startup_event():
|
157 |
+
"""Load model on startup"""
|
158 |
+
model_path = os.getenv("MODEL_PATH", "./trained_model")
|
159 |
+
await model_manager.load_model(model_path)
|
160 |
+
|
161 |
+
@app.get("/")
|
162 |
+
async def root():
|
163 |
+
"""API root endpoint"""
|
164 |
+
return {
|
165 |
+
"service": "Backend Code Generation API",
|
166 |
+
"version": "1.0.0",
|
167 |
+
"status": "running",
|
168 |
+
"model_loaded": model_manager.loaded,
|
169 |
+
"endpoints": {
|
170 |
+
"generate": "/api/v1/generate",
|
171 |
+
"status": "/api/v1/status/{task_id}",
|
172 |
+
"download": "/api/v1/download/{task_id}",
|
173 |
+
"health": "/health"
|
174 |
+
}
|
175 |
+
}
|
176 |
+
|
177 |
+
@app.get("/health")
|
178 |
+
async def health_check():
|
179 |
+
"""Health check endpoint"""
|
180 |
+
return {
|
181 |
+
"status": "OK",
|
182 |
+
"timestamp": datetime.utcnow().isoformat(),
|
183 |
+
"model_loaded": model_manager.loaded,
|
184 |
+
"device": model_manager.device if model_manager.loaded else None
|
185 |
+
}
|
186 |
+
|
187 |
+
@app.post("/api/v1/generate", response_model=GenerationResponse)
|
188 |
+
async def generate_backend(
|
189 |
+
request: CodeGenerationRequest,
|
190 |
+
background_tasks: BackgroundTasks
|
191 |
+
):
|
192 |
+
"""Generate a complete backend application"""
|
193 |
+
|
194 |
+
if not model_manager.loaded:
|
195 |
+
raise HTTPException(status_code=503, detail="Model not loaded")
|
196 |
+
|
197 |
+
# Create unique task ID
|
198 |
+
task_id = str(uuid.uuid4())
|
199 |
+
|
200 |
+
# Initialize task status
|
201 |
+
generation_tasks[task_id] = GenerationStatus(
|
202 |
+
task_id=task_id,
|
203 |
+
status="pending",
|
204 |
+
progress=0,
|
205 |
+
message="Task queued for processing"
|
206 |
+
)
|
207 |
+
|
208 |
+
# Start background generation
|
209 |
+
background_tasks.add_task(
|
210 |
+
generate_project_background,
|
211 |
+
task_id,
|
212 |
+
request
|
213 |
+
)
|
214 |
+
|
215 |
+
return GenerationResponse(
|
216 |
+
task_id=task_id,
|
217 |
+
status="accepted",
|
218 |
+
message="Code generation started",
|
219 |
+
estimated_time=60 # seconds
|
220 |
+
)
|
221 |
+
|
222 |
+
@app.get("/api/v1/status/{task_id}", response_model=GenerationStatus)
|
223 |
+
async def get_generation_status(task_id: str):
|
224 |
+
"""Get the status of a generation task"""
|
225 |
+
|
226 |
+
if task_id not in generation_tasks:
|
227 |
+
raise HTTPException(status_code=404, detail="Task not found")
|
228 |
+
|
229 |
+
return generation_tasks[task_id]
|
230 |
+
|
231 |
+
@app.get("/api/v1/download/{task_id}")
|
232 |
+
async def download_generated_project(task_id: str):
|
233 |
+
"""Download the generated project as a ZIP file"""
|
234 |
+
|
235 |
+
if task_id not in generation_tasks:
|
236 |
+
raise HTTPException(status_code=404, detail="Task not found")
|
237 |
+
|
238 |
+
task = generation_tasks[task_id]
|
239 |
+
|
240 |
+
if task.status != "completed":
|
241 |
+
raise HTTPException(status_code=400, detail="Generation not completed")
|
242 |
+
|
243 |
+
if not task.download_url:
|
244 |
+
raise HTTPException(status_code=404, detail="Download file not available")
|
245 |
+
|
246 |
+
if not os.path.exists(task.download_url):
|
247 |
+
raise HTTPException(status_code=404, detail="Download file not found")
|
248 |
+
|
249 |
+
return FileResponse(
|
250 |
+
path=task.download_url,
|
251 |
+
filename=f"generated_project_{task_id}.zip",
|
252 |
+
media_type="application/zip"
|
253 |
+
)
|
254 |
+
|
255 |
+
@app.delete("/api/v1/cleanup/{task_id}")
|
256 |
+
async def cleanup_task(task_id: str):
|
257 |
+
"""Clean up task files and data"""
|
258 |
+
|
259 |
+
if task_id not in generation_tasks:
|
260 |
+
raise HTTPException(status_code=404, detail="Task not found")
|
261 |
+
|
262 |
+
task = generation_tasks[task_id]
|
263 |
+
|
264 |
+
# Remove download file if exists
|
265 |
+
if task.download_url and os.path.exists(task.download_url):
|
266 |
+
os.remove(task.download_url)
|
267 |
+
|
268 |
+
# Remove task from memory
|
269 |
+
del generation_tasks[task_id]
|
270 |
+
|
271 |
+
return {"message": "Task cleaned up successfully"}
|
272 |
+
|
273 |
+
async def generate_project_background(task_id: str, request: CodeGenerationRequest):
|
274 |
+
"""Background task for generating the complete project"""
|
275 |
+
|
276 |
+
task = generation_tasks[task_id]
|
277 |
+
|
278 |
+
try:
|
279 |
+
# Update status
|
280 |
+
task.status = "processing"
|
281 |
+
task.progress = 10
|
282 |
+
task.message = "Analyzing requirements..."
|
283 |
+
|
284 |
+
# Create the generation prompt
|
285 |
+
prompt = create_generation_prompt(request)
|
286 |
+
|
287 |
+
# Update progress
|
288 |
+
task.progress = 30
|
289 |
+
task.message = "Generating application structure..."
|
290 |
+
|
291 |
+
# Generate code using the model
|
292 |
+
generated_code = model_manager.generate_code(prompt, max_tokens=1024)
|
293 |
+
|
294 |
+
# Update progress
|
295 |
+
task.progress = 60
|
296 |
+
task.message = "Processing generated code..."
|
297 |
+
|
298 |
+
# Parse and structure the generated code
|
299 |
+
project_files = parse_generated_code(generated_code, request)
|
300 |
+
|
301 |
+
# Update progress
|
302 |
+
task.progress = 80
|
303 |
+
task.message = "Creating project files..."
|
304 |
+
|
305 |
+
# Create downloadable ZIP file
|
306 |
+
zip_path = create_project_zip(task_id, project_files, request)
|
307 |
+
|
308 |
+
# Complete the task
|
309 |
+
task.status = "completed"
|
310 |
+
task.progress = 100
|
311 |
+
task.message = "Project generated successfully"
|
312 |
+
task.generated_files = {name: "Generated" for name in project_files.keys()}
|
313 |
+
task.download_url = zip_path
|
314 |
+
|
315 |
+
except Exception as e:
|
316 |
+
logger.error(f"Generation failed for task {task_id}: {e}")
|
317 |
+
task.status = "failed"
|
318 |
+
task.error = str(e)
|
319 |
+
task.message = "Generation failed"
|
320 |
+
|
321 |
+
def create_generation_prompt(request: CodeGenerationRequest) -> str:
|
322 |
+
"""Create the prompt for the model"""
|
323 |
+
|
324 |
+
prompt_parts = [
|
325 |
+
f"Description: {request.description}",
|
326 |
+
f"Framework: {request.framework}",
|
327 |
+
f"Language: {request.language}",
|
328 |
+
]
|
329 |
+
|
330 |
+
if request.requirements:
|
331 |
+
prompt_parts.append("Requirements:")
|
332 |
+
for req in request.requirements:
|
333 |
+
prompt_parts.append(f"- {req}")
|
334 |
+
|
335 |
+
if request.project_name:
|
336 |
+
prompt_parts.append(f"Project Name: {request.project_name}")
|
337 |
+
|
338 |
+
prompt_parts.append("Generate the complete backend application with all necessary files:")
|
339 |
+
|
340 |
+
return "\n".join(prompt_parts)
|
341 |
+
|
342 |
+
def parse_generated_code(generated_code: str, request: CodeGenerationRequest) -> Dict[str, str]:
|
343 |
+
"""Parse the generated code into individual files"""
|
344 |
+
|
345 |
+
files = {}
|
346 |
+
|
347 |
+
# Simple parsing logic - in production, this should be more sophisticated
|
348 |
+
lines = generated_code.split('\n')
|
349 |
+
current_file = None
|
350 |
+
current_content = []
|
351 |
+
|
352 |
+
for line in lines:
|
353 |
+
if line.startswith('--- ') and line.endswith(' ---'):
|
354 |
+
# Save previous file
|
355 |
+
if current_file:
|
356 |
+
files[current_file] = '\n'.join(current_content)
|
357 |
+
|
358 |
+
# Start new file
|
359 |
+
current_file = line.replace('--- ', '').replace(' ---', '').strip()
|
360 |
+
current_content = []
|
361 |
+
|
362 |
+
elif current_file and not line.startswith('--- End ---'):
|
363 |
+
current_content.append(line)
|
364 |
+
|
365 |
+
# Save last file
|
366 |
+
if current_file and current_content:
|
367 |
+
files[current_file] = '\n'.join(current_content)
|
368 |
+
|
369 |
+
# If parsing failed, create basic structure based on framework
|
370 |
+
if not files:
|
371 |
+
files = create_fallback_structure(request)
|
372 |
+
|
373 |
+
return files
|
374 |
+
|
375 |
+
def create_fallback_structure(request: CodeGenerationRequest) -> Dict[str, str]:
|
376 |
+
"""Create a basic project structure if parsing fails"""
|
377 |
+
|
378 |
+
if request.framework.lower() == 'fastapi':
|
379 |
+
return {
|
380 |
+
'main.py': f'''from fastapi import FastAPI
|
381 |
+
|
382 |
+
app = FastAPI(title="{request.description}")
|
383 |
+
|
384 |
+
@app.get("/")
|
385 |
+
async def root():
|
386 |
+
return {{"message": "Hello from {request.description}"}}
|
387 |
+
|
388 |
+
@app.get("/health")
|
389 |
+
async def health():
|
390 |
+
return {{"status": "OK"}}
|
391 |
+
''',
|
392 |
+
'requirements.txt': '''fastapi==0.104.1
|
393 |
+
uvicorn[standard]==0.24.0'''
|
394 |
+
}
|
395 |
+
|
396 |
+
elif request.framework.lower() == 'express':
|
397 |
+
return {
|
398 |
+
'app.js': f'''const express = require('express');
|
399 |
+
const app = express();
|
400 |
+
|
401 |
+
app.get('/', (req, res) => {{
|
402 |
+
res.json({{ message: 'Hello from {request.description}' }});
|
403 |
+
}});
|
404 |
+
|
405 |
+
app.get('/health', (req, res) => {{
|
406 |
+
res.json({{ status: 'OK' }});
|
407 |
+
}});
|
408 |
+
|
409 |
+
const PORT = process.env.PORT || 3000;
|
410 |
+
app.listen(PORT, () => {{
|
411 |
+
console.log(`Server running on port ${{PORT}}`);
|
412 |
+
}});
|
413 |
+
''',
|
414 |
+
'package.json': json.dumps({
|
415 |
+
"name": request.project_name or "generated-backend",
|
416 |
+
"version": "1.0.0",
|
417 |
+
"main": "app.js",
|
418 |
+
"dependencies": {
|
419 |
+
"express": "^4.18.2"
|
420 |
+
}
|
421 |
+
}, indent=2)
|
422 |
+
}
|
423 |
+
|
424 |
+
else:
|
425 |
+
return {
|
426 |
+
'README.md': f'# {request.description}\n\nGenerated backend application using {request.framework}'
|
427 |
+
}
|
428 |
+
|
429 |
+
def create_project_zip(task_id: str, files: Dict[str, str], request: CodeGenerationRequest) -> str:
|
430 |
+
"""Create a ZIP file containing all project files"""
|
431 |
+
|
432 |
+
# Create temporary directory for the ZIP file
|
433 |
+
temp_dir = tempfile.gettempdir()
|
434 |
+
zip_path = os.path.join(temp_dir, f"project_{task_id}.zip")
|
435 |
+
|
436 |
+
project_name = request.project_name or f"generated_{request.framework}_app"
|
437 |
+
|
438 |
+
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
439 |
+
for filename, content in files.items():
|
440 |
+
# Add each file to the ZIP
|
441 |
+
arcname = f"{project_name}/{filename}"
|
442 |
+
zipf.writestr(arcname, content)
|
443 |
+
|
444 |
+
# Add a README with setup instructions
|
445 |
+
setup_instructions = get_setup_instructions(request.framework)
|
446 |
+
zipf.writestr(f"{project_name}/SETUP.md", setup_instructions)
|
447 |
+
|
448 |
+
return zip_path
|
449 |
+
|
450 |
+
def get_setup_instructions(framework: str) -> str:
|
451 |
+
"""Get setup instructions for the framework"""
|
452 |
+
|
453 |
+
instructions = {
|
454 |
+
'fastapi': '''# Setup Instructions
|
455 |
+
|
456 |
+
1. Install dependencies:
|
457 |
+
```bash
|
458 |
+
pip install -r requirements.txt
|
459 |
+
```
|
460 |
+
|
461 |
+
2. Run the application:
|
462 |
+
```bash
|
463 |
+
uvicorn main:app --reload
|
464 |
+
```
|
465 |
+
|
466 |
+
3. Access the API:
|
467 |
+
- API: http://localhost:8000
|
468 |
+
- Docs: http://localhost:8000/docs
|
469 |
+
''',
|
470 |
+
'express': '''# Setup Instructions
|
471 |
+
|
472 |
+
1. Install dependencies:
|
473 |
+
```bash
|
474 |
+
npm install
|
475 |
+
```
|
476 |
+
|
477 |
+
2. Run the application:
|
478 |
+
```bash
|
479 |
+
node app.js
|
480 |
+
```
|
481 |
+
|
482 |
+
3. Access the API:
|
483 |
+
- API: http://localhost:3000
|
484 |
+
''',
|
485 |
+
'django': '''# Setup Instructions
|
486 |
+
|
487 |
+
1. Install dependencies:
|
488 |
+
```bash
|
489 |
+
pip install -r requirements.txt
|
490 |
+
```
|
491 |
+
|
492 |
+
2. Run migrations:
|
493 |
+
```bash
|
494 |
+
python manage.py migrate
|
495 |
+
```
|
496 |
+
|
497 |
+
3. Run the application:
|
498 |
+
```bash
|
499 |
+
python manage.py runserver
|
500 |
+
```
|
501 |
+
|
502 |
+
4. Access the API:
|
503 |
+
- API: http://localhost:8000
|
504 |
+
- Admin: http://localhost:8000/admin
|
505 |
+
''',
|
506 |
+
'flask': '''# Setup Instructions
|
507 |
+
|
508 |
+
1. Install dependencies:
|
509 |
+
```bash
|
510 |
+
pip install -r requirements.txt
|
511 |
+
```
|
512 |
+
|
513 |
+
2. Run the application:
|
514 |
+
```bash
|
515 |
+
python run.py
|
516 |
+
```
|
517 |
+
|
518 |
+
3. Access the API:
|
519 |
+
- API: http://localhost:5000
|
520 |
+
'''
|
521 |
+
}
|
522 |
+
|
523 |
+
return instructions.get(framework, '# Setup Instructions\n\nRefer to the framework documentation for setup instructions.')
|
524 |
+
|
525 |
+
# Additional utility endpoints
|
526 |
+
@app.get("/api/v1/frameworks")
|
527 |
+
async def list_supported_frameworks():
|
528 |
+
"""List supported frameworks and languages"""
|
529 |
+
return {
|
530 |
+
"frameworks": [
|
531 |
+
{
|
532 |
+
"name": "fastapi",
|
533 |
+
"language": "python",
|
534 |
+
"description": "Modern, fast, web framework for building APIs"
|
535 |
+
},
|
536 |
+
{
|
537 |
+
"name": "express",
|
538 |
+
"language": "javascript",
|
539 |
+
"description": "Fast, unopinionated web framework for Node.js"
|
540 |
+
},
|
541 |
+
{
|
542 |
+
"name": "django",
|
543 |
+
"language": "python",
|
544 |
+
"description": "High-level Python web framework"
|
545 |
+
},
|
546 |
+
{
|
547 |
+
"name": "flask",
|
548 |
+
"language": "python",
|
549 |
+
"description": "Lightweight WSGI web application framework"
|
550 |
+
}
|
551 |
+
]
|
552 |
+
}
|
553 |
+
|
554 |
+
@app.get("/api/v1/examples")
|
555 |
+
async def get_example_requests():
|
556 |
+
"""Get example generation requests"""
|
557 |
+
return {
|
558 |
+
"examples": [
|
559 |
+
{
|
560 |
+
"name": "E-commerce API",
|
561 |
+
"request": {
|
562 |
+
"description": "Complete e-commerce backend with user management and product catalog",
|
563 |
+
"framework": "fastapi",
|
564 |
+
"language": "python",
|
565 |
+
"requirements": [
|
566 |
+
"User registration and authentication",
|
567 |
+
"Product CRUD operations",
|
568 |
+
"Shopping cart functionality",
|
569 |
+
"Order management",
|
570 |
+
"Payment processing integration"
|
571 |
+
]
|
572 |
+
}
|
573 |
+
},
|
574 |
+
{
|
575 |
+
"name": "Task Management System",
|
576 |
+
"request": {
|
577 |
+
"description": "Task management system with team collaboration",
|
578 |
+
"framework": "express",
|
579 |
+
"language": "javascript",
|
580 |
+
"requirements": [
|
581 |
+
"User authentication with JWT",
|
582 |
+
"Task CRUD operations",
|
583 |
+
"Team and project management",
|
584 |
+
"Real-time notifications",
|
585 |
+
"File attachments"
|
586 |
+
]
|
587 |
+
}
|
588 |
+
},
|
589 |
+
{
|
590 |
+
"name": "Blog Platform",
|
591 |
+
"request": {
|
592 |
+
"description": "Blog platform with content management",
|
593 |
+
"framework": "django",
|
594 |
+
"language": "python",
|
595 |
+
"requirements": [
|
596 |
+
"Article management",
|
597 |
+
"User comments and ratings",
|
598 |
+
"Category and tag system",
|
599 |
+
"SEO optimization",
|
600 |
+
"Media file handling"
|
601 |
+
]
|
602 |
+
}
|
603 |
+
}
|
604 |
+
]
|
605 |
+
}
|
606 |
+
|
607 |
+
if __name__ == "__main__":
|
608 |
+
import uvicorn
|
609 |
+
uvicorn.run(
|
610 |
+
"api_service:app",
|
611 |
+
host="0.0.0.0",
|
612 |
+
port=8000,
|
613 |
+
reload=True
|
614 |
+
)
|
data/.gitkeep
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
evaluation/.gitkeep
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
models/.gitkeep
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
raw_dataset.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
datasets
|
4 |
+
pandas
|
5 |
+
numpy
|
6 |
+
aiohttp
|
7 |
+
requests
|
8 |
+
accelerate
|
9 |
+
fastapi
|
10 |
+
uvicorn
|
11 |
+
python-multipart
|
12 |
+
|
setup-guide.md
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Backend Code Generation Model - Setup & Usage Guide
|
2 |
+
|
3 |
+
## 🛠️ Installation & Setup
|
4 |
+
|
5 |
+
### 1. Install Dependencies
|
6 |
+
```bash
|
7 |
+
pip install torch transformers datasets pandas numpy aiohttp requests
|
8 |
+
pip install accelerate # For faster training
|
9 |
+
```
|
10 |
+
|
11 |
+
### 2. Set Environment Variables
|
12 |
+
```bash
|
13 |
+
# Optional: GitHub token for collecting real repositories
|
14 |
+
export GITHUB_TOKEN="your_github_token_here"
|
15 |
+
|
16 |
+
# For GPU training (if available)
|
17 |
+
export CUDA_VISIBLE_DEVICES=0
|
18 |
+
```
|
19 |
+
|
20 |
+
### 3. Directory Structure
|
21 |
+
```
|
22 |
+
backend-ai-trainer/
|
23 |
+
├── training_pipeline.py # Main pipeline code
|
24 |
+
├── data/
|
25 |
+
│ ├── raw_dataset.json # Collected training data
|
26 |
+
│ └── processed/ # Preprocessed data
|
27 |
+
├── models/
|
28 |
+
│ ├── backend_code_model/ # Trained model output
|
29 |
+
│ └── checkpoints/ # Training checkpoints
|
30 |
+
└── evaluation/
|
31 |
+
├── test_cases.json # Test scenarios
|
32 |
+
└── results/ # Evaluation results
|
33 |
+
```
|
34 |
+
|
35 |
+
## 🏃♂️ Quick Start
|
36 |
+
|
37 |
+
### Option A: Full Automated Pipeline
|
38 |
+
```python
|
39 |
+
import asyncio
|
40 |
+
from training_pipeline import TrainingPipeline
|
41 |
+
|
42 |
+
config = {
|
43 |
+
'base_model': 'microsoft/DialoGPT-medium',
|
44 |
+
'output_dir': './models/backend_code_model',
|
45 |
+
'github_token': 'your_token_here', # Optional
|
46 |
+
}
|
47 |
+
|
48 |
+
pipeline = TrainingPipeline(config)
|
49 |
+
asyncio.run(pipeline.run_full_pipeline())
|
50 |
+
```
|
51 |
+
|
52 |
+
### Option B: Step-by-Step Execution
|
53 |
+
|
54 |
+
#### Step 1: Collect Training Data
|
55 |
+
```python
|
56 |
+
from training_pipeline import DataCollector
|
57 |
+
import asyncio
|
58 |
+
|
59 |
+
collector = DataCollector()
|
60 |
+
|
61 |
+
# Collect from GitHub (requires token)
|
62 |
+
github_queries = [
|
63 |
+
'express api backend',
|
64 |
+
'fastapi python backend',
|
65 |
+
'django rest api',
|
66 |
+
'nodejs backend server',
|
67 |
+
'flask api backend'
|
68 |
+
]
|
69 |
+
|
70 |
+
asyncio.run(collector.collect_github_repositories(github_queries, max_repos=100))
|
71 |
+
|
72 |
+
# Generate synthetic examples
|
73 |
+
collector.generate_synthetic_examples(count=500)
|
74 |
+
|
75 |
+
# Save dataset
|
76 |
+
collector.save_dataset('training_data.json')
|
77 |
+
```
|
78 |
+
|
79 |
+
#### Step 2: Preprocess Data
|
80 |
+
```python
|
81 |
+
from training_pipeline import DataPreprocessor
|
82 |
+
|
83 |
+
preprocessor = DataPreprocessor()
|
84 |
+
processed_examples = preprocessor.preprocess_examples(collector.collected_examples)
|
85 |
+
training_dataset = preprocessor.create_training_dataset(processed_examples)
|
86 |
+
|
87 |
+
print(f"Created dataset with {len(training_dataset)} examples")
|
88 |
+
```
|
89 |
+
|
90 |
+
#### Step 3: Train Model
|
91 |
+
```python
|
92 |
+
from training_pipeline import CodeGenerationModel
|
93 |
+
|
94 |
+
model = CodeGenerationModel('microsoft/DialoGPT-medium')
|
95 |
+
model.fine_tune(training_dataset, output_dir='./trained_model')
|
96 |
+
```
|
97 |
+
|
98 |
+
#### Step 4: Generate Code
|
99 |
+
```python
|
100 |
+
# Generate a complete backend application
|
101 |
+
generated_code = model.generate_code(
|
102 |
+
description="E-commerce API with user authentication and product management",
|
103 |
+
framework="fastapi",
|
104 |
+
language="python"
|
105 |
+
)
|
106 |
+
|
107 |
+
print("Generated Backend Application:")
|
108 |
+
print("=" * 50)
|
109 |
+
print(generated_code)
|
110 |
+
```
|
111 |
+
|
112 |
+
## 🎯 Training Configuration Options
|
113 |
+
|
114 |
+
### Model Selection
|
115 |
+
```python
|
116 |
+
# Lightweight for testing
|
117 |
+
config['base_model'] = 'microsoft/DialoGPT-small'
|
118 |
+
|
119 |
+
# Balanced performance
|
120 |
+
config['base_model'] = 'microsoft/DialoGPT-medium'
|
121 |
+
|
122 |
+
# High quality (requires more resources)
|
123 |
+
config['base_model'] = 'microsoft/DialoGPT-large'
|
124 |
+
```
|
125 |
+
|
126 |
+
### Training Parameters
|
127 |
+
```python
|
128 |
+
training_config = {
|
129 |
+
'num_epochs': 5, # More epochs = better learning
|
130 |
+
'batch_size': 4, # Adjust based on GPU memory
|
131 |
+
'learning_rate': 5e-5, # Conservative learning rate
|
132 |
+
'max_length': 2048, # Maximum token length
|
133 |
+
'warmup_steps': 500, # Learning rate warmup
|
134 |
+
'save_steps': 1000, # Checkpoint frequency
|
135 |
+
}
|
136 |
+
```
|
137 |
+
|
138 |
+
### Framework Coverage
|
139 |
+
The pipeline supports these backend frameworks:
|
140 |
+
|
141 |
+
**Node.js Frameworks:**
|
142 |
+
- Express.js - Most popular Node.js framework
|
143 |
+
- NestJS - Enterprise-grade framework
|
144 |
+
- Koa.js - Lightweight alternative
|
145 |
+
|
146 |
+
**Python Frameworks:**
|
147 |
+
- FastAPI - Modern, high-performance API framework
|
148 |
+
- Django - Full-featured web framework
|
149 |
+
- Flask - Lightweight and flexible
|
150 |
+
|
151 |
+
**Go Frameworks:**
|
152 |
+
- Gin - HTTP web framework
|
153 |
+
- Fiber - Express-inspired framework
|
154 |
+
|
155 |
+
## 📊 Evaluation & Testing
|
156 |
+
|
157 |
+
### Automatic Quality Assessment
|
158 |
+
```python
|
159 |
+
from training_pipeline import ModelEvaluator
|
160 |
+
|
161 |
+
evaluator = ModelEvaluator()
|
162 |
+
|
163 |
+
# Test specific code generation
|
164 |
+
generated_code = model.generate_code(
|
165 |
+
description="User authentication API with JWT tokens",
|
166 |
+
framework="express",
|
167 |
+
language="javascript"
|
168 |
+
)
|
169 |
+
|
170 |
+
# Get quality scores
|
171 |
+
quality_scores = evaluator.evaluate_code_quality(generated_code, "javascript")
|
172 |
+
print(f"Syntax Correctness: {quality_scores['syntax_correctness']:.2f}")
|
173 |
+
print(f"Completeness: {quality_scores['completeness']:.2f}")
|
174 |
+
print(f"Best Practices: {quality_scores['best_practices']:.2f}")
|
175 |
+
```
|
176 |
+
|
177 |
+
### Comprehensive Benchmarking
|
178 |
+
```python
|
179 |
+
test_cases = [
|
180 |
+
{
|
181 |
+
'description': 'REST API for task management with user authentication',
|
182 |
+
'framework': 'express',
|
183 |
+
'language': 'javascript'
|
184 |
+
},
|
185 |
+
{
|
186 |
+
'description': 'GraphQL API for social media platform',
|
187 |
+
'framework': 'fastapi',
|
188 |
+
'language': 'python'
|
189 |
+
},
|
190 |
+
{
|
191 |
+
'description': 'Microservice for payment processing',
|
192 |
+
'framework': 'gin',
|
193 |
+
'language': 'go'
|
194 |
+
}
|
195 |
+
]
|
196 |
+
|
197 |
+
benchmark_results = evaluator.benchmark_model(model, test_cases)
|
198 |
+
print("Overall Performance:", benchmark_results)
|
199 |
+
```
|
200 |
+
|
201 |
+
## 🚀 Advanced Usage
|
202 |
+
|
203 |
+
### Custom Data Sources
|
204 |
+
```python
|
205 |
+
# Add your own training examples
|
206 |
+
custom_examples = [
|
207 |
+
{
|
208 |
+
'description': 'Custom API requirement',
|
209 |
+
'requirements': ['Custom feature 1', 'Custom feature 2'],
|
210 |
+
'framework': 'fastapi',
|
211 |
+
'language': 'python',
|
212 |
+
'code_files': {
|
213 |
+
'main.py': '# Your custom code here',
|
214 |
+
'requirements.txt': 'fastapi\nuvicorn'
|
215 |
+
}
|
216 |
+
}
|
217 |
+
]
|
218 |
+
|
219 |
+
# Add to training data
|
220 |
+
collector.collected_examples.extend([CodeExample(**ex) for ex in custom_examples])
|
221 |
+
```
|
222 |
+
|
223 |
+
### Fine-tuning on Specific Domains
|
224 |
+
```python
|
225 |
+
# Focus training on specific application types
|
226 |
+
domain_specific_queries = [
|
227 |
+
'microservices architecture',
|
228 |
+
'api gateway implementation',
|
229 |
+
'database orm integration',
|
230 |
+
'authentication middleware',
|
231 |
+
'rate limiting api'
|
232 |
+
]
|
233 |
+
|
234 |
+
asyncio.run(collector.collect_github_repositories(domain_specific_queries))
|
235 |
+
```
|
236 |
+
|
237 |
+
### Export Trained Model
|
238 |
+
```python
|
239 |
+
# Save model for deployment
|
240 |
+
model.model.save_pretrained('./production_model')
|
241 |
+
model.tokenizer.save_pretrained('./production_model')
|
242 |
+
|
243 |
+
# Load for inference
|
244 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
245 |
+
|
246 |
+
production_model = AutoModelForCausalLM.from_pretrained('./production_model')
|
247 |
+
production_tokenizer = AutoTokenizer.from_pretrained('./production_model')
|
248 |
+
```
|
249 |
+
|
250 |
+
## 🔧 Troubleshooting
|
251 |
+
|
252 |
+
### Common Issues
|
253 |
+
|
254 |
+
**1. Out of Memory Errors**
|
255 |
+
```python
|
256 |
+
# Reduce batch size
|
257 |
+
config['per_device_train_batch_size'] = 1
|
258 |
+
config['gradient_accumulation_steps'] = 4
|
259 |
+
|
260 |
+
# Use gradient checkpointing
|
261 |
+
config['gradient_checkpointing'] = True
|
262 |
+
```
|
263 |
+
|
264 |
+
**2. Slow Training**
|
265 |
+
```python
|
266 |
+
# Enable mixed precision (if GPU supports it)
|
267 |
+
config['fp16'] = True
|
268 |
+
|
269 |
+
# Use multiple GPUs
|
270 |
+
config['dataloader_num_workers'] = 4
|
271 |
+
```
|
272 |
+
|
273 |
+
**3. Poor Code Quality**
|
274 |
+
```python
|
275 |
+
# Increase training data diversity
|
276 |
+
collector.generate_synthetic_examples(count=1000)
|
277 |
+
|
278 |
+
# Extend training duration
|
279 |
+
config['num_train_epochs'] = 10
|
280 |
+
```
|
281 |
+
|
282 |
+
### Performance Optimization
|
283 |
+
|
284 |
+
**For CPU Training:**
|
285 |
+
```python
|
286 |
+
config['dataloader_pin_memory'] = False
|
287 |
+
config['per_device_train_batch_size'] = 1
|
288 |
+
```
|
289 |
+
|
290 |
+
**For GPU Training:**
|
291 |
+
```python
|
292 |
+
config['fp16'] = True
|
293 |
+
config['dataloader_pin_memory'] = True
|
294 |
+
config['per_device_train_batch_size'] = 4
|
295 |
+
```
|
296 |
+
|
297 |
+
## 📈 Expected Results
|
298 |
+
|
299 |
+
After training on ~500-1000 examples, you should expect:
|
300 |
+
|
301 |
+
- **Syntax Correctness**: 85-95%
|
302 |
+
- **Code Completeness**: 80-90%
|
303 |
+
- **Best Practices**: 70-85%
|
304 |
+
- **Framework Coverage**: All major Node.js and Python frameworks
|
305 |
+
- **Generation Speed**: 2-5 seconds per application
|
306 |
+
|
307 |
+
## 🔄 Continuous Improvement
|
308 |
+
|
309 |
+
### Regular Retraining
|
310 |
+
```python
|
311 |
+
# Schedule weekly data collection
|
312 |
+
import schedule
|
313 |
+
|
314 |
+
def update_training_data():
|
315 |
+
asyncio.run(collector.collect_github_repositories(['new backend trends']))
|
316 |
+
|
317 |
+
schedule.every().week.do(update_training_data)
|
318 |
+
```
|
319 |
+
|
320 |
+
### A/B Testing Different Models
|
321 |
+
```python
|
322 |
+
models_to_compare = [
|
323 |
+
'microsoft/DialoGPT-medium',
|
324 |
+
'microsoft/DialoGPT-large',
|
325 |
+
'gpt2-medium'
|
326 |
+
]
|
327 |
+
|
328 |
+
for base_model in models_to_compare:
|
329 |
+
model = CodeGenerationModel(base_model)
|
330 |
+
results = evaluator.benchmark_model(model, test_cases)
|
331 |
+
print(f"{base_model}: {results}")
|
332 |
+
```
|
333 |
+
|
334 |
+
## 🎯 Next Steps
|
335 |
+
|
336 |
+
1. **Start Small**: Begin with synthetic data and 100-200 examples
|
337 |
+
2. **Add Real Data**: Integrate GitHub repositories gradually
|
338 |
+
3. **Evaluate Regularly**: Monitor quality metrics after each training session
|
339 |
+
4. **Expand Frameworks**: Add support for new frameworks as needed
|
340 |
+
5. **Production Deploy**: Export model for API deployment
|
341 |
+
|
342 |
+
This pipeline provides a complete foundation for building your own backend code generation AI. The modular design allows you to customize and extend each component based on your specific needs.
|
test_api.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Test script for the Backend Code Generation API
|
4 |
+
===============================================
|
5 |
+
|
6 |
+
Simple test script to verify the API is working correctly.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import requests
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
import os
|
13 |
+
|
14 |
+
# API base URL
|
15 |
+
BASE_URL = "http://localhost:8000"
|
16 |
+
|
17 |
+
def test_health():
|
18 |
+
"""Test the health endpoint"""
|
19 |
+
print("Testing health endpoint...")
|
20 |
+
response = requests.get(f"{BASE_URL}/health")
|
21 |
+
print(f"Status: {response.status_code}")
|
22 |
+
print(f"Response: {response.json()}")
|
23 |
+
return response.status_code == 200
|
24 |
+
|
25 |
+
def test_generate_code():
|
26 |
+
"""Test code generation"""
|
27 |
+
print("\nTesting code generation...")
|
28 |
+
|
29 |
+
# Test request
|
30 |
+
request_data = {
|
31 |
+
"description": "Simple REST API for task management",
|
32 |
+
"framework": "fastapi",
|
33 |
+
"language": "python",
|
34 |
+
"requirements": [
|
35 |
+
"User authentication",
|
36 |
+
"Task CRUD operations",
|
37 |
+
"Task status tracking"
|
38 |
+
],
|
39 |
+
"project_name": "task-manager-api"
|
40 |
+
}
|
41 |
+
|
42 |
+
# Submit generation request
|
43 |
+
response = requests.post(f"{BASE_URL}/api/v1/generate", json=request_data)
|
44 |
+
print(f"Generation request status: {response.status_code}")
|
45 |
+
|
46 |
+
if response.status_code == 200:
|
47 |
+
result = response.json()
|
48 |
+
task_id = result["task_id"]
|
49 |
+
print(f"Task ID: {task_id}")
|
50 |
+
|
51 |
+
# Poll for completion
|
52 |
+
print("Polling for completion...")
|
53 |
+
for i in range(30): # Wait up to 5 minutes
|
54 |
+
status_response = requests.get(f"{BASE_URL}/api/v1/status/{task_id}")
|
55 |
+
if status_response.status_code == 200:
|
56 |
+
status = status_response.json()
|
57 |
+
print(f"Status: {status['status']} - {status['message']} ({status['progress']}%)")
|
58 |
+
|
59 |
+
if status["status"] == "completed":
|
60 |
+
print("✅ Generation completed!")
|
61 |
+
if status.get("download_url"):
|
62 |
+
print(f"Download URL: {status['download_url']}")
|
63 |
+
return True
|
64 |
+
elif status["status"] == "failed":
|
65 |
+
print(f"❌ Generation failed: {status.get('error', 'Unknown error')}")
|
66 |
+
return False
|
67 |
+
else:
|
68 |
+
print(f"Failed to get status: {status_response.status_code}")
|
69 |
+
return False
|
70 |
+
|
71 |
+
time.sleep(10) # Wait 10 seconds between polls
|
72 |
+
|
73 |
+
print("⏰ Timeout waiting for completion")
|
74 |
+
return False
|
75 |
+
else:
|
76 |
+
print(f"❌ Generation request failed: {response.text}")
|
77 |
+
return False
|
78 |
+
|
79 |
+
def test_frameworks():
|
80 |
+
"""Test frameworks endpoint"""
|
81 |
+
print("\nTesting frameworks endpoint...")
|
82 |
+
response = requests.get(f"{BASE_URL}/api/v1/frameworks")
|
83 |
+
print(f"Status: {response.status_code}")
|
84 |
+
if response.status_code == 200:
|
85 |
+
frameworks = response.json()
|
86 |
+
print(f"Supported frameworks: {len(frameworks['frameworks'])}")
|
87 |
+
for fw in frameworks['frameworks']:
|
88 |
+
print(f" - {fw['name']} ({fw['language']})")
|
89 |
+
return True
|
90 |
+
return False
|
91 |
+
|
92 |
+
def test_examples():
|
93 |
+
"""Test examples endpoint"""
|
94 |
+
print("\nTesting examples endpoint...")
|
95 |
+
response = requests.get(f"{BASE_URL}/api/v1/examples")
|
96 |
+
print(f"Status: {response.status_code}")
|
97 |
+
if response.status_code == 200:
|
98 |
+
examples = response.json()
|
99 |
+
print(f"Available examples: {len(examples['examples'])}")
|
100 |
+
for ex in examples['examples']:
|
101 |
+
print(f" - {ex['name']}")
|
102 |
+
return True
|
103 |
+
return False
|
104 |
+
|
105 |
+
def main():
|
106 |
+
"""Run all tests"""
|
107 |
+
print("🚀 Testing Backend Code Generation API")
|
108 |
+
print("=" * 50)
|
109 |
+
|
110 |
+
# Check if API is running
|
111 |
+
try:
|
112 |
+
response = requests.get(f"{BASE_URL}/", timeout=5)
|
113 |
+
if response.status_code != 200:
|
114 |
+
print("❌ API is not running. Please start it with: python api_service.py")
|
115 |
+
return
|
116 |
+
except requests.exceptions.RequestException:
|
117 |
+
print("❌ Cannot connect to API. Please start it with: python api_service.py")
|
118 |
+
return
|
119 |
+
|
120 |
+
print("✅ API is running")
|
121 |
+
|
122 |
+
# Run tests
|
123 |
+
tests = [
|
124 |
+
("Health Check", test_health),
|
125 |
+
("Frameworks List", test_frameworks),
|
126 |
+
("Examples List", test_examples),
|
127 |
+
("Code Generation", test_generate_code),
|
128 |
+
]
|
129 |
+
|
130 |
+
results = []
|
131 |
+
for test_name, test_func in tests:
|
132 |
+
print(f"\n{'='*20} {test_name} {'='*20}")
|
133 |
+
try:
|
134 |
+
result = test_func()
|
135 |
+
results.append((test_name, result))
|
136 |
+
except Exception as e:
|
137 |
+
print(f"❌ Test failed with error: {e}")
|
138 |
+
results.append((test_name, False))
|
139 |
+
|
140 |
+
# Summary
|
141 |
+
print(f"\n{'='*50}")
|
142 |
+
print("📊 Test Results Summary:")
|
143 |
+
print("=" * 50)
|
144 |
+
|
145 |
+
passed = 0
|
146 |
+
for test_name, result in results:
|
147 |
+
status = "✅ PASS" if result else "❌ FAIL"
|
148 |
+
print(f"{test_name}: {status}")
|
149 |
+
if result:
|
150 |
+
passed += 1
|
151 |
+
|
152 |
+
print(f"\nPassed: {passed}/{len(results)} tests")
|
153 |
+
|
154 |
+
if passed == len(results):
|
155 |
+
print("🎉 All tests passed!")
|
156 |
+
else:
|
157 |
+
print("⚠️ Some tests failed. Check the output above for details.")
|
158 |
+
|
159 |
+
if __name__ == "__main__":
|
160 |
+
main()
|
training_pipeline.py
ADDED
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Backend Code Generation Model Training Pipeline
|
4 |
+
===============================================
|
5 |
+
|
6 |
+
A comprehensive training pipeline for building an AI model that generates
|
7 |
+
framework-agnostic backend code with full application scaffolding.
|
8 |
+
|
9 |
+
Features:
|
10 |
+
- Data collection from multiple sources
|
11 |
+
- Multi-framework support (Express.js, FastAPI, Django, Flask, etc.)
|
12 |
+
- Full application scaffolding generation
|
13 |
+
- Model training with transformer architecture
|
14 |
+
- Evaluation and benchmarking tools
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import json
|
19 |
+
import logging
|
20 |
+
import asyncio
|
21 |
+
import aiohttp
|
22 |
+
import pandas as pd
|
23 |
+
import numpy as np
|
24 |
+
from typing import Dict, List, Optional, Tuple, Any
|
25 |
+
from dataclasses import dataclass, asdict
|
26 |
+
from pathlib import Path
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
from torch.utils.data import Dataset, DataLoader
|
30 |
+
from transformers import (
|
31 |
+
AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
|
32 |
+
Trainer, DataCollatorForLanguageModeling
|
33 |
+
)
|
34 |
+
from datasets import Dataset as HFDataset
|
35 |
+
import ast
|
36 |
+
import subprocess
|
37 |
+
import tempfile
|
38 |
+
from concurrent.futures import ThreadPoolExecutor
|
39 |
+
import requests
|
40 |
+
import time
|
41 |
+
import random
|
42 |
+
|
43 |
+
# Configure logging
|
44 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
45 |
+
logger = logging.getLogger(__name__)
|
46 |
+
|
47 |
+
|
48 |
+
@dataclass
|
49 |
+
class CodeExample:
|
50 |
+
"""Represents a single training example"""
|
51 |
+
description: str
|
52 |
+
requirements: List[str]
|
53 |
+
framework: str
|
54 |
+
language: str
|
55 |
+
code_files: Dict[str, str] # filename -> content
|
56 |
+
project_structure: Dict[str, Any]
|
57 |
+
metadata: Dict[str, Any]
|
58 |
+
|
59 |
+
|
60 |
+
class DataCollector:
|
61 |
+
"""Collects training data from various sources"""
|
62 |
+
|
63 |
+
def __init__(self):
|
64 |
+
self.github_token = os.getenv('GITHUB_TOKEN')
|
65 |
+
self.collected_examples: List[CodeExample] = []
|
66 |
+
|
67 |
+
async def collect_github_repositories(self, queries: List[str], max_repos: int = 100):
|
68 |
+
"""Collect backend projects from GitHub"""
|
69 |
+
logger.info("Starting GitHub repository collection...")
|
70 |
+
|
71 |
+
headers = {'Authorization': f'token {self.github_token}'} if self.github_token else {}
|
72 |
+
|
73 |
+
async with aiohttp.ClientSession(headers=headers) as session:
|
74 |
+
per_query = max(1, max_repos // max(1, len(queries)))
|
75 |
+
for query in queries:
|
76 |
+
await self._search_github_repos(session, query, per_query)
|
77 |
+
|
78 |
+
async def _search_github_repos(self, session: aiohttp.ClientSession, query: str, limit: int):
|
79 |
+
"""Search GitHub for repositories matching query"""
|
80 |
+
url = f"https://api.github.com/search/repositories"
|
81 |
+
params = {
|
82 |
+
'q': query,
|
83 |
+
'sort': 'stars',
|
84 |
+
'order': 'desc',
|
85 |
+
'per_page': min(limit, 100)
|
86 |
+
}
|
87 |
+
|
88 |
+
try:
|
89 |
+
async with session.get(url, params=params) as response:
|
90 |
+
if response.status == 200:
|
91 |
+
data = await response.json()
|
92 |
+
for repo in data.get('items', []):
|
93 |
+
await self._process_repository(session, repo)
|
94 |
+
else:
|
95 |
+
logger.warning(f"GitHub API request failed: {response.status}")
|
96 |
+
except Exception as e:
|
97 |
+
logger.error(f"Error searching GitHub: {e}")
|
98 |
+
|
99 |
+
async def _process_repository(self, session: aiohttp.ClientSession, repo: Dict):
|
100 |
+
"""Process a single repository to extract code examples"""
|
101 |
+
logger.info(f"Processing repository: {repo.get('full_name', '<unknown>')}")
|
102 |
+
|
103 |
+
try:
|
104 |
+
contents_url = f"https://api.github.com/repos/{repo['full_name']}/contents"
|
105 |
+
async with session.get(contents_url) as response:
|
106 |
+
if response.status == 200:
|
107 |
+
contents = await response.json()
|
108 |
+
await self._extract_code_example(session, repo, contents)
|
109 |
+
except Exception as e:
|
110 |
+
logger.error(f"Error processing repository {repo.get('full_name')}: {e}")
|
111 |
+
|
112 |
+
async def _extract_code_example(self, session: aiohttp.ClientSession, repo: Dict, contents: List[Dict]):
|
113 |
+
"""Extract a structured code example from repository"""
|
114 |
+
framework = self._identify_framework(contents, repo.get('description', ''))
|
115 |
+
language = self._identify_language(contents)
|
116 |
+
|
117 |
+
if not framework or not language:
|
118 |
+
return
|
119 |
+
|
120 |
+
code_files: Dict[str, str] = {}
|
121 |
+
for item in contents:
|
122 |
+
if item.get('type') == 'file' and self._is_important_file(item.get('name', '')):
|
123 |
+
try:
|
124 |
+
async with session.get(item['download_url']) as response:
|
125 |
+
if response.status == 200:
|
126 |
+
content = await response.text()
|
127 |
+
code_files[item['name']] = content
|
128 |
+
except Exception:
|
129 |
+
continue
|
130 |
+
|
131 |
+
if code_files:
|
132 |
+
example = CodeExample(
|
133 |
+
description=repo.get('description', ''),
|
134 |
+
requirements=self._extract_requirements(code_files),
|
135 |
+
framework=framework,
|
136 |
+
language=language,
|
137 |
+
code_files=code_files,
|
138 |
+
project_structure=self._analyze_structure(contents),
|
139 |
+
metadata={
|
140 |
+
'stars': repo.get('stargazers_count', 0),
|
141 |
+
'forks': repo.get('forks_count', 0),
|
142 |
+
'url': repo.get('html_url'),
|
143 |
+
'created_at': repo.get('created_at'),
|
144 |
+
'updated_at': repo.get('updated_at')
|
145 |
+
}
|
146 |
+
)
|
147 |
+
self.collected_examples.append(example)
|
148 |
+
|
149 |
+
def _identify_framework(self, contents: List[Dict], description: str) -> Optional[str]:
|
150 |
+
"""Identify the backend framework used"""
|
151 |
+
filenames = [item.get('name', '').lower() for item in contents if item.get('type') == 'file']
|
152 |
+
|
153 |
+
frameworks = {
|
154 |
+
'express': ['package.json', 'app.js', 'server.js'],
|
155 |
+
'fastapi': ['requirements.txt', 'main.py', 'app.py'],
|
156 |
+
'django': ['manage.py', 'settings.py', 'requirements.txt'],
|
157 |
+
'flask': ['app.py', 'requirements.txt'],
|
158 |
+
'nestjs': ['nest-cli.json', 'package.json'],
|
159 |
+
'koa': ['package.json'],
|
160 |
+
'gin': ['go.mod', 'main.go'],
|
161 |
+
'fiber': ['go.mod', 'main.go'],
|
162 |
+
}
|
163 |
+
|
164 |
+
for framework, required_files in frameworks.items():
|
165 |
+
if all(any(req in filename for filename in filenames) for req in required_files[:2]):
|
166 |
+
return framework
|
167 |
+
|
168 |
+
desc_lower = description.lower()
|
169 |
+
for framework in frameworks.keys():
|
170 |
+
if framework in desc_lower:
|
171 |
+
return framework
|
172 |
+
|
173 |
+
return None
|
174 |
+
|
175 |
+
def _identify_language(self, contents: List[Dict]) -> Optional[str]:
|
176 |
+
"""Identify primary programming language"""
|
177 |
+
extensions: Dict[str, int] = {}
|
178 |
+
for item in contents:
|
179 |
+
if item.get('type') == 'file':
|
180 |
+
ext = Path(item.get('name', '')).suffix.lower()
|
181 |
+
if ext:
|
182 |
+
extensions[ext] = extensions.get(ext, 0) + 1
|
183 |
+
|
184 |
+
lang_map = {
|
185 |
+
'.js': 'javascript',
|
186 |
+
'.ts': 'typescript',
|
187 |
+
'.py': 'python',
|
188 |
+
'.go': 'go',
|
189 |
+
'.java': 'java',
|
190 |
+
'.cs': 'csharp',
|
191 |
+
'.rb': 'ruby',
|
192 |
+
'.php': 'php'
|
193 |
+
}
|
194 |
+
|
195 |
+
if extensions:
|
196 |
+
most_common_ext = max(extensions.items(), key=lambda x: x[1])[0]
|
197 |
+
return lang_map.get(most_common_ext)
|
198 |
+
|
199 |
+
return None
|
200 |
+
|
201 |
+
def _is_important_file(self, filename: str) -> bool:
|
202 |
+
"""Check if file is important for training"""
|
203 |
+
important_patterns = [
|
204 |
+
'package.json', 'requirements.txt', 'go.mod', 'pom.xml',
|
205 |
+
'dockerfile', 'docker-compose.yml', 'readme.md',
|
206 |
+
'app.py', 'main.py', 'server.js', 'app.js', 'index.js',
|
207 |
+
'settings.py', 'config.py', 'routes.py', 'models.py',
|
208 |
+
'controller.js', 'service.js', 'middleware.js'
|
209 |
+
]
|
210 |
+
|
211 |
+
filename_lower = filename.lower()
|
212 |
+
return any(pattern in filename_lower for pattern in important_patterns)
|
213 |
+
|
214 |
+
def _extract_requirements(self, code_files: Dict[str, str]) -> List[str]:
|
215 |
+
"""Extract functional requirements from code"""
|
216 |
+
requirements: List[str] = []
|
217 |
+
|
218 |
+
if 'package.json' in code_files:
|
219 |
+
try:
|
220 |
+
pkg_data = json.loads(code_files['package.json'])
|
221 |
+
deps = list(pkg_data.get('dependencies', {}).keys())
|
222 |
+
requirements.extend([f"Uses {dep}" for dep in deps[:5]])
|
223 |
+
except Exception:
|
224 |
+
pass
|
225 |
+
|
226 |
+
if 'requirements.txt' in code_files:
|
227 |
+
lines = code_files['requirements.txt'].strip().split('\n')
|
228 |
+
deps = [line.split('==')[0].split('>=')[0].strip() for line in lines if line.strip()]
|
229 |
+
requirements.extend([f"Uses {dep}" for dep in deps[:5]])
|
230 |
+
|
231 |
+
for filename, content in code_files.items():
|
232 |
+
if filename.endswith(('.js', '.py')):
|
233 |
+
endpoints = self._extract_endpoints(content)
|
234 |
+
requirements.extend(endpoints)
|
235 |
+
|
236 |
+
return requirements[:10]
|
237 |
+
|
238 |
+
def _extract_endpoints(self, code_content: str) -> List[str]:
|
239 |
+
"""Extract API endpoints from code"""
|
240 |
+
endpoints: List[str] = []
|
241 |
+
lines = code_content.split('\n')
|
242 |
+
|
243 |
+
for line in lines:
|
244 |
+
s = line.strip()
|
245 |
+
if any(method in s for method in ['app.get(', 'app.post(', 'app.put(', 'app.delete(']):
|
246 |
+
endpoints.append(f"Implements {s}")
|
247 |
+
elif any(decorator in s for decorator in ['@app.get(', '@app.post(', '@app.put(', '@app.delete(']):
|
248 |
+
endpoints.append(f"Implements {s}")
|
249 |
+
elif 'def ' in s and any(word in s for word in ['get', 'post', 'put', 'delete']):
|
250 |
+
endpoints.append(f"Implements {s}")
|
251 |
+
|
252 |
+
return endpoints[:5]
|
253 |
+
|
254 |
+
def _analyze_structure(self, contents: List[Dict]) -> Dict[str, Any]:
|
255 |
+
"""Analyze project structure"""
|
256 |
+
structure: Dict[str, Any] = {
|
257 |
+
'files': [],
|
258 |
+
'directories': [],
|
259 |
+
'total_files': 0,
|
260 |
+
'has_tests': False,
|
261 |
+
'has_docs': False
|
262 |
+
}
|
263 |
+
|
264 |
+
for item in contents:
|
265 |
+
if item.get('type') == 'file':
|
266 |
+
name = item.get('name', '')
|
267 |
+
structure['files'].append(name)
|
268 |
+
structure['total_files'] += 1
|
269 |
+
if 'test' in name.lower():
|
270 |
+
structure['has_tests'] = True
|
271 |
+
if name.lower() in ['readme.md', 'docs.md']:
|
272 |
+
structure['has_docs'] = True
|
273 |
+
elif item.get('type') == 'dir':
|
274 |
+
structure['directories'].append(item.get('name', ''))
|
275 |
+
|
276 |
+
return structure
|
277 |
+
|
278 |
+
def generate_synthetic_examples(self, count: int = 100):
|
279 |
+
"""Generate synthetic training examples"""
|
280 |
+
logger.info(f"Generating {count} synthetic examples...")
|
281 |
+
|
282 |
+
templates = [
|
283 |
+
{
|
284 |
+
'description': 'REST API for user management',
|
285 |
+
'requirements': ['User registration', 'User authentication', 'Profile management'],
|
286 |
+
'frameworks': ['express', 'fastapi', 'django']
|
287 |
+
},
|
288 |
+
{
|
289 |
+
'description': 'E-commerce backend API',
|
290 |
+
'requirements': ['Product catalog', 'Shopping cart', 'Order processing', 'Payment integration'],
|
291 |
+
'frameworks': ['nestjs', 'fastapi', 'django']
|
292 |
+
},
|
293 |
+
{
|
294 |
+
'description': 'Task management system',
|
295 |
+
'requirements': ['Task CRUD operations', 'User assignments', 'Status tracking'],
|
296 |
+
'frameworks': ['express', 'flask', 'gin']
|
297 |
+
},
|
298 |
+
{
|
299 |
+
'description': 'Blog platform backend',
|
300 |
+
'requirements': ['Article management', 'User comments', 'Category system'],
|
301 |
+
'frameworks': ['express', 'django', 'fastapi']
|
302 |
+
}
|
303 |
+
]
|
304 |
+
|
305 |
+
for _ in range(count):
|
306 |
+
template = random.choice(templates)
|
307 |
+
framework = random.choice(template['frameworks'])
|
308 |
+
|
309 |
+
code_files = self._generate_code_for_template(template, framework)
|
310 |
+
|
311 |
+
example = CodeExample(
|
312 |
+
description=template['description'],
|
313 |
+
requirements=template['requirements'],
|
314 |
+
framework=framework,
|
315 |
+
language='python' if framework in ['fastapi', 'django', 'flask'] else 'javascript',
|
316 |
+
code_files=code_files,
|
317 |
+
project_structure=self._generate_synthetic_structure(framework),
|
318 |
+
metadata={'synthetic': True}
|
319 |
+
)
|
320 |
+
|
321 |
+
self.collected_examples.append(example)
|
322 |
+
|
323 |
+
def _generate_code_for_template(self, template: Dict, framework: str) -> Dict[str, str]:
|
324 |
+
"""Generate code files for a template and framework"""
|
325 |
+
if framework == 'express':
|
326 |
+
return {
|
327 |
+
'package.json': json.dumps({
|
328 |
+
"name": template['description'].lower().replace(' ', '-'),
|
329 |
+
"version": "1.0.0",
|
330 |
+
"dependencies": {
|
331 |
+
"express": "^4.18.0",
|
332 |
+
"mongoose": "^6.0.0",
|
333 |
+
"bcrypt": "^5.0.0",
|
334 |
+
"jsonwebtoken": "^8.5.0"
|
335 |
+
}
|
336 |
+
}, indent=2),
|
337 |
+
'app.js': '''const express = require('express');
|
338 |
+
const mongoose = require('mongoose');
|
339 |
+
const app = express();
|
340 |
+
|
341 |
+
// Middleware
|
342 |
+
app.use(express.json());
|
343 |
+
|
344 |
+
// Routes
|
345 |
+
app.get('/health', (req, res) => {
|
346 |
+
res.json({ status: 'OK' });
|
347 |
+
});
|
348 |
+
|
349 |
+
// Start server
|
350 |
+
const PORT = process.env.PORT || 3000;
|
351 |
+
app.listen(PORT, () => {
|
352 |
+
console.log(`Server running on port ${PORT}`);
|
353 |
+
});
|
354 |
+
|
355 |
+
module.exports = app;'''
|
356 |
+
}
|
357 |
+
elif framework == 'fastapi':
|
358 |
+
return {
|
359 |
+
'requirements.txt': '''fastapi==0.68.0
|
360 |
+
uvicorn==0.15.0
|
361 |
+
sqlalchemy==1.4.23
|
362 |
+
pydantic==1.8.2''',
|
363 |
+
'main.py': '''from fastapi import FastAPI, HTTPException
|
364 |
+
from pydantic import BaseModel
|
365 |
+
from typing import List, Optional
|
366 |
+
|
367 |
+
app = FastAPI()
|
368 |
+
|
369 |
+
class Item(BaseModel):
|
370 |
+
id: Optional[int] = None
|
371 |
+
name: str
|
372 |
+
description: str
|
373 |
+
|
374 |
+
@app.get("/")
|
375 |
+
async def root():
|
376 |
+
return {"message": "Hello World"}
|
377 |
+
|
378 |
+
@app.get("/health")
|
379 |
+
async def health_check():
|
380 |
+
return {"status": "OK"}
|
381 |
+
|
382 |
+
if __name__ == "__main__":
|
383 |
+
import uvicorn
|
384 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)'''
|
385 |
+
}
|
386 |
+
else:
|
387 |
+
return {'placeholder.txt': 'Generated code placeholder'}
|
388 |
+
|
389 |
+
def _generate_synthetic_structure(self, framework: str) -> Dict[str, Any]:
|
390 |
+
"""Generate project structure for framework"""
|
391 |
+
if framework in ['express', 'nestjs']:
|
392 |
+
return {
|
393 |
+
'files': ['package.json', 'app.js', 'README.md'],
|
394 |
+
'directories': ['routes', 'controllers', 'middleware', 'models'],
|
395 |
+
'total_files': 3,
|
396 |
+
'has_tests': True,
|
397 |
+
'has_docs': True
|
398 |
+
}
|
399 |
+
elif framework in ['fastapi', 'django', 'flask']:
|
400 |
+
return {
|
401 |
+
'files': ['requirements.txt', 'main.py', 'README.md'],
|
402 |
+
'directories': ['models', 'routes', 'services'],
|
403 |
+
'total_files': 3,
|
404 |
+
'has_tests': True,
|
405 |
+
'has_docs': True
|
406 |
+
}
|
407 |
+
else:
|
408 |
+
return {}
|
409 |
+
|
410 |
+
def save_dataset(self, filepath: str):
|
411 |
+
"""Save collected examples to file"""
|
412 |
+
data = [asdict(example) for example in self.collected_examples]
|
413 |
+
with open(filepath, 'w', encoding='utf-8') as f:
|
414 |
+
json.dump(data, f, indent=2, ensure_ascii=False)
|
415 |
+
logger.info(f"Saved {len(data)} examples to {filepath}")
|
416 |
+
|
417 |
+
|
418 |
+
class DataPreprocessor:
|
419 |
+
"""Preprocesses collected data for training"""
|
420 |
+
|
421 |
+
def __init__(self, tokenizer_name: str = "microsoft/DialoGPT-medium"):
|
422 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
423 |
+
if self.tokenizer.pad_token is None:
|
424 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
425 |
+
# Ensure we do not exceed model's maximum positional embeddings (GPT-2/DialoGPT: 1024)
|
426 |
+
try:
|
427 |
+
model_max = getattr(self.tokenizer, 'model_max_length', 1024)
|
428 |
+
# Some tokenizers set a very large sentinel value; cap at 1024 for GPT-2 family
|
429 |
+
if model_max and model_max > 0 and model_max < 100000:
|
430 |
+
self.max_length = min(1024, int(model_max))
|
431 |
+
else:
|
432 |
+
self.max_length = 1024
|
433 |
+
except Exception:
|
434 |
+
self.max_length = 1024
|
435 |
+
|
436 |
+
def preprocess_examples(self, examples: List[CodeExample]) -> List[Dict[str, str]]:
|
437 |
+
"""Convert examples to training format"""
|
438 |
+
processed: List[Dict[str, str]] = []
|
439 |
+
|
440 |
+
for example in examples:
|
441 |
+
input_text = self._create_input_text(example)
|
442 |
+
output_text = self._create_output_text(example)
|
443 |
+
|
444 |
+
processed.append({
|
445 |
+
'input': input_text,
|
446 |
+
'output': output_text,
|
447 |
+
'framework': example.framework,
|
448 |
+
'language': example.language
|
449 |
+
})
|
450 |
+
|
451 |
+
return processed
|
452 |
+
|
453 |
+
def _create_input_text(self, example: CodeExample) -> str:
|
454 |
+
"""Create model input text"""
|
455 |
+
input_parts: List[str] = [
|
456 |
+
f"Description: {example.description}",
|
457 |
+
f"Framework: {example.framework}",
|
458 |
+
f"Language: {example.language}",
|
459 |
+
"Requirements:",
|
460 |
+
]
|
461 |
+
|
462 |
+
for req in example.requirements:
|
463 |
+
input_parts.append(f"- {req}")
|
464 |
+
|
465 |
+
input_parts.append("Generate the backend application:")
|
466 |
+
|
467 |
+
return "\n".join(input_parts)
|
468 |
+
|
469 |
+
def _create_output_text(self, example: CodeExample) -> str:
|
470 |
+
"""Create model output text"""
|
471 |
+
output_parts: List[str] = []
|
472 |
+
|
473 |
+
output_parts.append("Project Structure:")
|
474 |
+
for directory in example.project_structure.get('directories', []):
|
475 |
+
output_parts.append(f"/{directory}/")
|
476 |
+
|
477 |
+
output_parts.append("\nGenerated Files:")
|
478 |
+
|
479 |
+
for filename, content in example.code_files.items():
|
480 |
+
output_parts.append(f"\n--- {filename} ---")
|
481 |
+
output_parts.append(content)
|
482 |
+
output_parts.append("--- End ---")
|
483 |
+
|
484 |
+
return "\n".join(output_parts)
|
485 |
+
|
486 |
+
def create_training_dataset(self, processed_examples: List[Dict[str, str]]) -> HFDataset:
|
487 |
+
"""Create Hugging Face dataset for training"""
|
488 |
+
|
489 |
+
def tokenize_function(examples: Dict[str, List[str]]):
|
490 |
+
texts: List[str] = []
|
491 |
+
for inp, out in zip(examples['input'], examples['output']):
|
492 |
+
text = f"<|startoftext|>{inp}<|separator|>{out}<|endoftext|>"
|
493 |
+
texts.append(text)
|
494 |
+
|
495 |
+
return self.tokenizer(
|
496 |
+
texts,
|
497 |
+
truncation=True,
|
498 |
+
padding=True,
|
499 |
+
max_length=self.max_length
|
500 |
+
)
|
501 |
+
|
502 |
+
dataset_dict = {
|
503 |
+
'input': [ex['input'] for ex in processed_examples],
|
504 |
+
'output': [ex['output'] for ex in processed_examples],
|
505 |
+
'framework': [ex['framework'] for ex in processed_examples],
|
506 |
+
'language': [ex['language'] for ex in processed_examples]
|
507 |
+
}
|
508 |
+
|
509 |
+
dataset = HFDataset.from_dict(dataset_dict)
|
510 |
+
tokenized_dataset = dataset.map(tokenize_function, batched=True)
|
511 |
+
|
512 |
+
return tokenized_dataset
|
513 |
+
|
514 |
+
|
515 |
+
class CodeGenerationModel:
|
516 |
+
"""Custom model for backend code generation"""
|
517 |
+
|
518 |
+
def __init__(self, base_model: str = "microsoft/DialoGPT-medium"):
|
519 |
+
self.base_model = base_model
|
520 |
+
self.tokenizer = AutoTokenizer.from_pretrained(base_model)
|
521 |
+
self.model = AutoModelForCausalLM.from_pretrained(base_model)
|
522 |
+
|
523 |
+
if self.tokenizer.pad_token is None:
|
524 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
525 |
+
|
526 |
+
def fine_tune(self, dataset: HFDataset, output_dir: str = "./trained_model"):
|
527 |
+
"""Fine-tune the model on backend code generation"""
|
528 |
+
logger.info("Starting model fine-tuning...")
|
529 |
+
|
530 |
+
training_args = TrainingArguments(
|
531 |
+
output_dir=output_dir,
|
532 |
+
overwrite_output_dir=True,
|
533 |
+
num_train_epochs=1, # Reduced from 3
|
534 |
+
per_device_train_batch_size=1, # Reduced from 2 for memory
|
535 |
+
per_device_eval_batch_size=1, # Reduced from 2
|
536 |
+
warmup_steps=50, # Reduced from 500
|
537 |
+
max_steps=100, # Drastically reduced from 2000
|
538 |
+
logging_steps=10, # More frequent logging
|
539 |
+
save_steps=50, # More frequent saves
|
540 |
+
save_total_limit=2,
|
541 |
+
prediction_loss_only=True,
|
542 |
+
fp16=torch.cuda.is_available(),
|
543 |
+
dataloader_pin_memory=False,
|
544 |
+
gradient_accumulation_steps=4, # Accumulate gradients for effective larger batch
|
545 |
+
learning_rate=5e-5, # Explicit learning rate
|
546 |
+
)
|
547 |
+
|
548 |
+
data_collator = DataCollatorForLanguageModeling(
|
549 |
+
tokenizer=self.tokenizer,
|
550 |
+
mlm=False,
|
551 |
+
)
|
552 |
+
|
553 |
+
train_size = int(0.8 * len(dataset))
|
554 |
+
eval_size = len(dataset) - train_size
|
555 |
+
train_dataset, eval_dataset = torch.utils.data.random_split(
|
556 |
+
dataset, [train_size, eval_size]
|
557 |
+
)
|
558 |
+
|
559 |
+
trainer = Trainer(
|
560 |
+
model=self.model,
|
561 |
+
args=training_args,
|
562 |
+
data_collator=data_collator,
|
563 |
+
train_dataset=train_dataset,
|
564 |
+
eval_dataset=eval_dataset,
|
565 |
+
)
|
566 |
+
|
567 |
+
trainer.train()
|
568 |
+
trainer.save_model()
|
569 |
+
|
570 |
+
logger.info("Fine-tuning completed!")
|
571 |
+
|
572 |
+
def generate_code(self, description: str, framework: str, language: str) -> str:
|
573 |
+
"""Generate backend code for given requirements"""
|
574 |
+
input_text = (
|
575 |
+
f"Description: {description}\n"
|
576 |
+
f"Framework: {framework}\n"
|
577 |
+
f"Language: {language}\n"
|
578 |
+
f"Generate the backend application:"
|
579 |
+
)
|
580 |
+
|
581 |
+
# Respect model's max position embeddings (GPT-2/DialoGPT is typically 1024)
|
582 |
+
model_max_len = getattr(self.tokenizer, 'model_max_length', 1024)
|
583 |
+
max_len = 1024 if model_max_len is None or model_max_len > 100000 else min(1024, int(model_max_len))
|
584 |
+
|
585 |
+
inputs = self.tokenizer.encode(input_text, return_tensors='pt', truncation=True, max_length=max_len)
|
586 |
+
|
587 |
+
with torch.no_grad():
|
588 |
+
outputs = self.model.generate(
|
589 |
+
inputs,
|
590 |
+
max_length=max_len,
|
591 |
+
num_return_sequences=1,
|
592 |
+
temperature=0.7,
|
593 |
+
do_sample=True,
|
594 |
+
pad_token_id=self.tokenizer.eos_token_id
|
595 |
+
)
|
596 |
+
|
597 |
+
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
598 |
+
return generated_text[len(input_text):]
|
599 |
+
|
600 |
+
|
601 |
+
class ModelEvaluator:
|
602 |
+
"""Evaluates model performance"""
|
603 |
+
|
604 |
+
def __init__(self):
|
605 |
+
self.metrics: Dict[str, float] = {}
|
606 |
+
|
607 |
+
def evaluate_code_quality(self, generated_code: str, language: str) -> Dict[str, float]:
|
608 |
+
"""Evaluate generated code quality"""
|
609 |
+
metrics = {
|
610 |
+
'syntax_correctness': self._check_syntax(generated_code, language),
|
611 |
+
'completeness': self._check_completeness(generated_code),
|
612 |
+
'best_practices': self._check_best_practices(generated_code, language)
|
613 |
+
}
|
614 |
+
|
615 |
+
return metrics
|
616 |
+
|
617 |
+
def _check_syntax(self, code: str, language: str) -> float:
|
618 |
+
"""Check if generated code has valid syntax"""
|
619 |
+
if language == 'python':
|
620 |
+
try:
|
621 |
+
ast.parse(code)
|
622 |
+
return 1.0
|
623 |
+
except SyntaxError:
|
624 |
+
return 0.0
|
625 |
+
elif language == 'javascript':
|
626 |
+
if '{' in code and '}' in code:
|
627 |
+
return 0.8
|
628 |
+
return 0.5
|
629 |
+
|
630 |
+
return 0.5
|
631 |
+
|
632 |
+
def _check_completeness(self, code: str) -> float:
|
633 |
+
"""Check if code appears complete"""
|
634 |
+
completeness_indicators = [
|
635 |
+
'import', 'require', 'function', 'def', 'class',
|
636 |
+
'app.', 'router.', '@app.', 'app.listen', 'if __name__'
|
637 |
+
]
|
638 |
+
|
639 |
+
indicators_found = sum(1 for indicator in completeness_indicators if indicator in code)
|
640 |
+
return min(indicators_found / 3.0, 1.0)
|
641 |
+
|
642 |
+
def _check_best_practices(self, code: str, language: str) -> float:
|
643 |
+
"""Check adherence to best practices"""
|
644 |
+
best_practices_score = 0.0
|
645 |
+
|
646 |
+
if 'try:' in code or 'catch' in code:
|
647 |
+
best_practices_score += 0.2
|
648 |
+
|
649 |
+
if any(comment in code for comment in ['#', '//', '/*']):
|
650 |
+
best_practices_score += 0.2
|
651 |
+
|
652 |
+
if language == 'python':
|
653 |
+
if 'if __name__ == "__main__"' in code:
|
654 |
+
best_practices_score += 0.2
|
655 |
+
elif language == 'javascript':
|
656 |
+
if 'const' in code or 'let' in code:
|
657 |
+
best_practices_score += 0.2
|
658 |
+
|
659 |
+
return min(best_practices_score, 1.0)
|
660 |
+
|
661 |
+
def benchmark_model(self, model: 'CodeGenerationModel', test_cases: List[Dict]) -> Dict[str, float]:
|
662 |
+
"""Benchmark model on test cases"""
|
663 |
+
total_scores = {'syntax': 0.0, 'completeness': 0.0, 'best_practices': 0.0}
|
664 |
+
|
665 |
+
for i, test_case in enumerate(test_cases):
|
666 |
+
generated_code = model.generate_code(
|
667 |
+
test_case['description'],
|
668 |
+
test_case['framework'],
|
669 |
+
test_case['language']
|
670 |
+
)
|
671 |
+
|
672 |
+
scores = self.evaluate_code_quality(generated_code, test_case['language'])
|
673 |
+
|
674 |
+
total_scores['syntax'] += scores['syntax_correctness']
|
675 |
+
total_scores['completeness'] += scores['completeness']
|
676 |
+
total_scores['best_practices'] += scores['best_practices']
|
677 |
+
|
678 |
+
logger.info(f"Test case {i+1}: {scores}")
|
679 |
+
|
680 |
+
num_cases = max(1, len(test_cases))
|
681 |
+
avg_scores = {key: value / num_cases for key, value in total_scores.items()}
|
682 |
+
|
683 |
+
return avg_scores
|
684 |
+
|
685 |
+
|
686 |
+
class TrainingPipeline:
|
687 |
+
"""Main training pipeline orchestrator"""
|
688 |
+
|
689 |
+
def __init__(self, config: Dict[str, Any]):
|
690 |
+
self.config = config
|
691 |
+
self.data_collector = DataCollector()
|
692 |
+
self.preprocessor = DataPreprocessor(config.get('tokenizer', 'microsoft/DialoGPT-medium'))
|
693 |
+
self.model = CodeGenerationModel(config.get('base_model', 'microsoft/DialoGPT-medium'))
|
694 |
+
self.evaluator = ModelEvaluator()
|
695 |
+
|
696 |
+
async def run_full_pipeline(self):
|
697 |
+
"""Run the complete training pipeline"""
|
698 |
+
logger.info("Starting full training pipeline...")
|
699 |
+
|
700 |
+
logger.info("Step 1: Collecting training data...")
|
701 |
+
|
702 |
+
if self.data_collector.github_token:
|
703 |
+
github_queries = [
|
704 |
+
'express api backend',
|
705 |
+
'fastapi python backend',
|
706 |
+
'django rest api',
|
707 |
+
'nodejs backend server',
|
708 |
+
'flask api backend'
|
709 |
+
]
|
710 |
+
await self.data_collector.collect_github_repositories(github_queries, max_repos=50)
|
711 |
+
|
712 |
+
self.data_collector.generate_synthetic_examples(count=200)
|
713 |
+
|
714 |
+
self.data_collector.save_dataset('raw_dataset.json')
|
715 |
+
|
716 |
+
logger.info("Step 2: Preprocessing data...")
|
717 |
+
processed_examples = self.preprocessor.preprocess_examples(self.data_collector.collected_examples)
|
718 |
+
training_dataset = self.preprocessor.create_training_dataset(processed_examples)
|
719 |
+
|
720 |
+
logger.info("Step 3: Training model...")
|
721 |
+
self.model.fine_tune(training_dataset, output_dir=self.config.get('output_dir', './trained_model'))
|
722 |
+
|
723 |
+
logger.info("Step 4: Evaluating model...")
|
724 |
+
test_cases = [
|
725 |
+
{
|
726 |
+
'description': 'REST API for user management with authentication',
|
727 |
+
'framework': 'express',
|
728 |
+
'language': 'javascript'
|
729 |
+
},
|
730 |
+
{
|
731 |
+
'description': 'FastAPI backend for e-commerce platform',
|
732 |
+
'framework': 'fastapi',
|
733 |
+
'language': 'python'
|
734 |
+
},
|
735 |
+
{
|
736 |
+
'description': 'Django REST API for blog platform',
|
737 |
+
'framework': 'django',
|
738 |
+
'language': 'python'
|
739 |
+
}
|
740 |
+
]
|
741 |
+
|
742 |
+
benchmark_results = self.evaluator.benchmark_model(self.model, test_cases)
|
743 |
+
logger.info(f"Benchmark results: {benchmark_results}")
|
744 |
+
|
745 |
+
logger.info("Training pipeline completed!")
|
746 |
+
return benchmark_results
|
747 |
+
|
748 |
+
|
749 |
+
if __name__ == "__main__":
|
750 |
+
config = {
|
751 |
+
'base_model': 'microsoft/DialoGPT-medium',
|
752 |
+
'tokenizer': 'microsoft/DialoGPT-medium',
|
753 |
+
'output_dir': './backend_code_model',
|
754 |
+
'github_token': os.getenv('GITHUB_TOKEN'),
|
755 |
+
}
|
756 |
+
|
757 |
+
pipeline = TrainingPipeline(config)
|
758 |
+
|
759 |
+
asyncio.run(pipeline.run_full_pipeline())
|
760 |
+
|
761 |
+
logger.info("\nTesting trained model...")
|
762 |
+
generated_code = pipeline.model.generate_code(
|
763 |
+
description="Create a REST API for managing tasks with CRUD operations",
|
764 |
+
framework="express",
|
765 |
+
language="javascript"
|
766 |
+
)
|
767 |
+
|
768 |
+
print("\nGenerated Code:")
|
769 |
+
print("=" * 50)
|
770 |
+
print(generated_code)
|
771 |
+
|
772 |
+
|