Techta commited on
Commit
472e2e9
·
0 Parent(s):
api_service.py ADDED
@@ -0,0 +1,614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Backend Code Generation API Service
4
+ ===================================
5
+
6
+ Production-ready API service for serving the trained backend code generation model.
7
+ Provides RESTful endpoints for generating complete backend applications.
8
+ """
9
+
10
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from fastapi.responses import StreamingResponse, FileResponse
13
+ from pydantic import BaseModel, Field
14
+ from typing import List, Dict, Optional, Any
15
+ import torch
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer
17
+ import json
18
+ import zipfile
19
+ import tempfile
20
+ import os
21
+ import uuid
22
+ from datetime import datetime
23
+ import asyncio
24
+ import logging
25
+ from pathlib import Path
26
+
27
+ # Configure logging
28
+ logging.basicConfig(level=logging.INFO)
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # Pydantic models for API
32
+ class CodeGenerationRequest(BaseModel):
33
+ description: str = Field(..., description="Description of the backend application to generate")
34
+ framework: str = Field(..., description="Target framework (express, fastapi, django, flask)")
35
+ language: str = Field(..., description="Programming language (javascript, python)")
36
+ requirements: List[str] = Field(default=[], description="List of specific requirements")
37
+ project_name: Optional[str] = Field(default=None, description="Custom project name")
38
+
39
+ class Config:
40
+ schema_extra = {
41
+ "example": {
42
+ "description": "E-commerce API with user authentication and product management",
43
+ "framework": "fastapi",
44
+ "language": "python",
45
+ "requirements": [
46
+ "User registration and login",
47
+ "JWT authentication",
48
+ "Product CRUD operations",
49
+ "Shopping cart functionality",
50
+ "Order management"
51
+ ],
52
+ "project_name": "ecommerce-api"
53
+ }
54
+ }
55
+
56
+ class GenerationResponse(BaseModel):
57
+ task_id: str
58
+ status: str
59
+ message: str
60
+ estimated_time: int
61
+
62
+ class GenerationStatus(BaseModel):
63
+ task_id: str
64
+ status: str # pending, processing, completed, failed
65
+ progress: int # 0-100
66
+ message: str
67
+ generated_files: Optional[Dict[str, str]] = None
68
+ download_url: Optional[str] = None
69
+ error: Optional[str] = None
70
+
71
+ class GeneratedProject(BaseModel):
72
+ project_name: str
73
+ framework: str
74
+ language: str
75
+ files: Dict[str, str]
76
+ structure: Dict[str, Any]
77
+ setup_instructions: List[str]
78
+ features: List[str]
79
+
80
+ # Global model instance
81
+ class ModelManager:
82
+ def __init__(self):
83
+ self.model = None
84
+ self.tokenizer = None
85
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
86
+ self.loaded = False
87
+
88
+ async def load_model(self, model_path: str = "./trained_model"):
89
+ """Load the trained model asynchronously"""
90
+ try:
91
+ logger.info(f"Loading model from {model_path} on {self.device}")
92
+
93
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
94
+ self.model = AutoModelForCausalLM.from_pretrained(
95
+ model_path,
96
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
97
+ device_map="auto" if self.device == "cuda" else None
98
+ )
99
+
100
+ if self.device == "cpu":
101
+ self.model = self.model.to(self.device)
102
+
103
+ self.loaded = True
104
+ logger.info("Model loaded successfully!")
105
+
106
+ except Exception as e:
107
+ logger.error(f"Failed to load model: {e}")
108
+ raise
109
+
110
+ def generate_code(self, prompt: str, max_tokens: int = 1024) -> str:
111
+ """Generate code using the trained model"""
112
+ if not self.loaded:
113
+ raise RuntimeError("Model not loaded")
114
+
115
+ inputs = self.tokenizer.encode(prompt, return_tensors='pt')
116
+ inputs = inputs.to(self.device)
117
+
118
+ with torch.no_grad():
119
+ outputs = self.model.generate(
120
+ inputs,
121
+ max_length=min(max_tokens, 1024),
122
+ num_return_sequences=1,
123
+ temperature=0.7,
124
+ do_sample=True,
125
+ top_p=0.9,
126
+ pad_token_id=self.tokenizer.eos_token_id,
127
+ repetition_penalty=1.1
128
+ )
129
+
130
+ generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
131
+ return generated_text[len(self.tokenizer.decode(inputs[0], skip_special_tokens=True)):]
132
+
133
+ # Global instances
134
+ model_manager = ModelManager()
135
+ generation_tasks = {} # Store generation tasks
136
+
137
+ # FastAPI app
138
+ app = FastAPI(
139
+ title="Backend Code Generation API",
140
+ description="AI-powered backend application generator",
141
+ version="1.0.0",
142
+ docs_url="/docs",
143
+ redoc_url="/redoc"
144
+ )
145
+
146
+ # CORS middleware
147
+ app.add_middleware(
148
+ CORSMiddleware,
149
+ allow_origins=["*"], # Configure for production
150
+ allow_credentials=True,
151
+ allow_methods=["*"],
152
+ allow_headers=["*"],
153
+ )
154
+
155
+ @app.on_event("startup")
156
+ async def startup_event():
157
+ """Load model on startup"""
158
+ model_path = os.getenv("MODEL_PATH", "./trained_model")
159
+ await model_manager.load_model(model_path)
160
+
161
+ @app.get("/")
162
+ async def root():
163
+ """API root endpoint"""
164
+ return {
165
+ "service": "Backend Code Generation API",
166
+ "version": "1.0.0",
167
+ "status": "running",
168
+ "model_loaded": model_manager.loaded,
169
+ "endpoints": {
170
+ "generate": "/api/v1/generate",
171
+ "status": "/api/v1/status/{task_id}",
172
+ "download": "/api/v1/download/{task_id}",
173
+ "health": "/health"
174
+ }
175
+ }
176
+
177
+ @app.get("/health")
178
+ async def health_check():
179
+ """Health check endpoint"""
180
+ return {
181
+ "status": "OK",
182
+ "timestamp": datetime.utcnow().isoformat(),
183
+ "model_loaded": model_manager.loaded,
184
+ "device": model_manager.device if model_manager.loaded else None
185
+ }
186
+
187
+ @app.post("/api/v1/generate", response_model=GenerationResponse)
188
+ async def generate_backend(
189
+ request: CodeGenerationRequest,
190
+ background_tasks: BackgroundTasks
191
+ ):
192
+ """Generate a complete backend application"""
193
+
194
+ if not model_manager.loaded:
195
+ raise HTTPException(status_code=503, detail="Model not loaded")
196
+
197
+ # Create unique task ID
198
+ task_id = str(uuid.uuid4())
199
+
200
+ # Initialize task status
201
+ generation_tasks[task_id] = GenerationStatus(
202
+ task_id=task_id,
203
+ status="pending",
204
+ progress=0,
205
+ message="Task queued for processing"
206
+ )
207
+
208
+ # Start background generation
209
+ background_tasks.add_task(
210
+ generate_project_background,
211
+ task_id,
212
+ request
213
+ )
214
+
215
+ return GenerationResponse(
216
+ task_id=task_id,
217
+ status="accepted",
218
+ message="Code generation started",
219
+ estimated_time=60 # seconds
220
+ )
221
+
222
+ @app.get("/api/v1/status/{task_id}", response_model=GenerationStatus)
223
+ async def get_generation_status(task_id: str):
224
+ """Get the status of a generation task"""
225
+
226
+ if task_id not in generation_tasks:
227
+ raise HTTPException(status_code=404, detail="Task not found")
228
+
229
+ return generation_tasks[task_id]
230
+
231
+ @app.get("/api/v1/download/{task_id}")
232
+ async def download_generated_project(task_id: str):
233
+ """Download the generated project as a ZIP file"""
234
+
235
+ if task_id not in generation_tasks:
236
+ raise HTTPException(status_code=404, detail="Task not found")
237
+
238
+ task = generation_tasks[task_id]
239
+
240
+ if task.status != "completed":
241
+ raise HTTPException(status_code=400, detail="Generation not completed")
242
+
243
+ if not task.download_url:
244
+ raise HTTPException(status_code=404, detail="Download file not available")
245
+
246
+ if not os.path.exists(task.download_url):
247
+ raise HTTPException(status_code=404, detail="Download file not found")
248
+
249
+ return FileResponse(
250
+ path=task.download_url,
251
+ filename=f"generated_project_{task_id}.zip",
252
+ media_type="application/zip"
253
+ )
254
+
255
+ @app.delete("/api/v1/cleanup/{task_id}")
256
+ async def cleanup_task(task_id: str):
257
+ """Clean up task files and data"""
258
+
259
+ if task_id not in generation_tasks:
260
+ raise HTTPException(status_code=404, detail="Task not found")
261
+
262
+ task = generation_tasks[task_id]
263
+
264
+ # Remove download file if exists
265
+ if task.download_url and os.path.exists(task.download_url):
266
+ os.remove(task.download_url)
267
+
268
+ # Remove task from memory
269
+ del generation_tasks[task_id]
270
+
271
+ return {"message": "Task cleaned up successfully"}
272
+
273
+ async def generate_project_background(task_id: str, request: CodeGenerationRequest):
274
+ """Background task for generating the complete project"""
275
+
276
+ task = generation_tasks[task_id]
277
+
278
+ try:
279
+ # Update status
280
+ task.status = "processing"
281
+ task.progress = 10
282
+ task.message = "Analyzing requirements..."
283
+
284
+ # Create the generation prompt
285
+ prompt = create_generation_prompt(request)
286
+
287
+ # Update progress
288
+ task.progress = 30
289
+ task.message = "Generating application structure..."
290
+
291
+ # Generate code using the model
292
+ generated_code = model_manager.generate_code(prompt, max_tokens=1024)
293
+
294
+ # Update progress
295
+ task.progress = 60
296
+ task.message = "Processing generated code..."
297
+
298
+ # Parse and structure the generated code
299
+ project_files = parse_generated_code(generated_code, request)
300
+
301
+ # Update progress
302
+ task.progress = 80
303
+ task.message = "Creating project files..."
304
+
305
+ # Create downloadable ZIP file
306
+ zip_path = create_project_zip(task_id, project_files, request)
307
+
308
+ # Complete the task
309
+ task.status = "completed"
310
+ task.progress = 100
311
+ task.message = "Project generated successfully"
312
+ task.generated_files = {name: "Generated" for name in project_files.keys()}
313
+ task.download_url = zip_path
314
+
315
+ except Exception as e:
316
+ logger.error(f"Generation failed for task {task_id}: {e}")
317
+ task.status = "failed"
318
+ task.error = str(e)
319
+ task.message = "Generation failed"
320
+
321
+ def create_generation_prompt(request: CodeGenerationRequest) -> str:
322
+ """Create the prompt for the model"""
323
+
324
+ prompt_parts = [
325
+ f"Description: {request.description}",
326
+ f"Framework: {request.framework}",
327
+ f"Language: {request.language}",
328
+ ]
329
+
330
+ if request.requirements:
331
+ prompt_parts.append("Requirements:")
332
+ for req in request.requirements:
333
+ prompt_parts.append(f"- {req}")
334
+
335
+ if request.project_name:
336
+ prompt_parts.append(f"Project Name: {request.project_name}")
337
+
338
+ prompt_parts.append("Generate the complete backend application with all necessary files:")
339
+
340
+ return "\n".join(prompt_parts)
341
+
342
+ def parse_generated_code(generated_code: str, request: CodeGenerationRequest) -> Dict[str, str]:
343
+ """Parse the generated code into individual files"""
344
+
345
+ files = {}
346
+
347
+ # Simple parsing logic - in production, this should be more sophisticated
348
+ lines = generated_code.split('\n')
349
+ current_file = None
350
+ current_content = []
351
+
352
+ for line in lines:
353
+ if line.startswith('--- ') and line.endswith(' ---'):
354
+ # Save previous file
355
+ if current_file:
356
+ files[current_file] = '\n'.join(current_content)
357
+
358
+ # Start new file
359
+ current_file = line.replace('--- ', '').replace(' ---', '').strip()
360
+ current_content = []
361
+
362
+ elif current_file and not line.startswith('--- End ---'):
363
+ current_content.append(line)
364
+
365
+ # Save last file
366
+ if current_file and current_content:
367
+ files[current_file] = '\n'.join(current_content)
368
+
369
+ # If parsing failed, create basic structure based on framework
370
+ if not files:
371
+ files = create_fallback_structure(request)
372
+
373
+ return files
374
+
375
+ def create_fallback_structure(request: CodeGenerationRequest) -> Dict[str, str]:
376
+ """Create a basic project structure if parsing fails"""
377
+
378
+ if request.framework.lower() == 'fastapi':
379
+ return {
380
+ 'main.py': f'''from fastapi import FastAPI
381
+
382
+ app = FastAPI(title="{request.description}")
383
+
384
+ @app.get("/")
385
+ async def root():
386
+ return {{"message": "Hello from {request.description}"}}
387
+
388
+ @app.get("/health")
389
+ async def health():
390
+ return {{"status": "OK"}}
391
+ ''',
392
+ 'requirements.txt': '''fastapi==0.104.1
393
+ uvicorn[standard]==0.24.0'''
394
+ }
395
+
396
+ elif request.framework.lower() == 'express':
397
+ return {
398
+ 'app.js': f'''const express = require('express');
399
+ const app = express();
400
+
401
+ app.get('/', (req, res) => {{
402
+ res.json({{ message: 'Hello from {request.description}' }});
403
+ }});
404
+
405
+ app.get('/health', (req, res) => {{
406
+ res.json({{ status: 'OK' }});
407
+ }});
408
+
409
+ const PORT = process.env.PORT || 3000;
410
+ app.listen(PORT, () => {{
411
+ console.log(`Server running on port ${{PORT}}`);
412
+ }});
413
+ ''',
414
+ 'package.json': json.dumps({
415
+ "name": request.project_name or "generated-backend",
416
+ "version": "1.0.0",
417
+ "main": "app.js",
418
+ "dependencies": {
419
+ "express": "^4.18.2"
420
+ }
421
+ }, indent=2)
422
+ }
423
+
424
+ else:
425
+ return {
426
+ 'README.md': f'# {request.description}\n\nGenerated backend application using {request.framework}'
427
+ }
428
+
429
+ def create_project_zip(task_id: str, files: Dict[str, str], request: CodeGenerationRequest) -> str:
430
+ """Create a ZIP file containing all project files"""
431
+
432
+ # Create temporary directory for the ZIP file
433
+ temp_dir = tempfile.gettempdir()
434
+ zip_path = os.path.join(temp_dir, f"project_{task_id}.zip")
435
+
436
+ project_name = request.project_name or f"generated_{request.framework}_app"
437
+
438
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
439
+ for filename, content in files.items():
440
+ # Add each file to the ZIP
441
+ arcname = f"{project_name}/{filename}"
442
+ zipf.writestr(arcname, content)
443
+
444
+ # Add a README with setup instructions
445
+ setup_instructions = get_setup_instructions(request.framework)
446
+ zipf.writestr(f"{project_name}/SETUP.md", setup_instructions)
447
+
448
+ return zip_path
449
+
450
+ def get_setup_instructions(framework: str) -> str:
451
+ """Get setup instructions for the framework"""
452
+
453
+ instructions = {
454
+ 'fastapi': '''# Setup Instructions
455
+
456
+ 1. Install dependencies:
457
+ ```bash
458
+ pip install -r requirements.txt
459
+ ```
460
+
461
+ 2. Run the application:
462
+ ```bash
463
+ uvicorn main:app --reload
464
+ ```
465
+
466
+ 3. Access the API:
467
+ - API: http://localhost:8000
468
+ - Docs: http://localhost:8000/docs
469
+ ''',
470
+ 'express': '''# Setup Instructions
471
+
472
+ 1. Install dependencies:
473
+ ```bash
474
+ npm install
475
+ ```
476
+
477
+ 2. Run the application:
478
+ ```bash
479
+ node app.js
480
+ ```
481
+
482
+ 3. Access the API:
483
+ - API: http://localhost:3000
484
+ ''',
485
+ 'django': '''# Setup Instructions
486
+
487
+ 1. Install dependencies:
488
+ ```bash
489
+ pip install -r requirements.txt
490
+ ```
491
+
492
+ 2. Run migrations:
493
+ ```bash
494
+ python manage.py migrate
495
+ ```
496
+
497
+ 3. Run the application:
498
+ ```bash
499
+ python manage.py runserver
500
+ ```
501
+
502
+ 4. Access the API:
503
+ - API: http://localhost:8000
504
+ - Admin: http://localhost:8000/admin
505
+ ''',
506
+ 'flask': '''# Setup Instructions
507
+
508
+ 1. Install dependencies:
509
+ ```bash
510
+ pip install -r requirements.txt
511
+ ```
512
+
513
+ 2. Run the application:
514
+ ```bash
515
+ python run.py
516
+ ```
517
+
518
+ 3. Access the API:
519
+ - API: http://localhost:5000
520
+ '''
521
+ }
522
+
523
+ return instructions.get(framework, '# Setup Instructions\n\nRefer to the framework documentation for setup instructions.')
524
+
525
+ # Additional utility endpoints
526
+ @app.get("/api/v1/frameworks")
527
+ async def list_supported_frameworks():
528
+ """List supported frameworks and languages"""
529
+ return {
530
+ "frameworks": [
531
+ {
532
+ "name": "fastapi",
533
+ "language": "python",
534
+ "description": "Modern, fast, web framework for building APIs"
535
+ },
536
+ {
537
+ "name": "express",
538
+ "language": "javascript",
539
+ "description": "Fast, unopinionated web framework for Node.js"
540
+ },
541
+ {
542
+ "name": "django",
543
+ "language": "python",
544
+ "description": "High-level Python web framework"
545
+ },
546
+ {
547
+ "name": "flask",
548
+ "language": "python",
549
+ "description": "Lightweight WSGI web application framework"
550
+ }
551
+ ]
552
+ }
553
+
554
+ @app.get("/api/v1/examples")
555
+ async def get_example_requests():
556
+ """Get example generation requests"""
557
+ return {
558
+ "examples": [
559
+ {
560
+ "name": "E-commerce API",
561
+ "request": {
562
+ "description": "Complete e-commerce backend with user management and product catalog",
563
+ "framework": "fastapi",
564
+ "language": "python",
565
+ "requirements": [
566
+ "User registration and authentication",
567
+ "Product CRUD operations",
568
+ "Shopping cart functionality",
569
+ "Order management",
570
+ "Payment processing integration"
571
+ ]
572
+ }
573
+ },
574
+ {
575
+ "name": "Task Management System",
576
+ "request": {
577
+ "description": "Task management system with team collaboration",
578
+ "framework": "express",
579
+ "language": "javascript",
580
+ "requirements": [
581
+ "User authentication with JWT",
582
+ "Task CRUD operations",
583
+ "Team and project management",
584
+ "Real-time notifications",
585
+ "File attachments"
586
+ ]
587
+ }
588
+ },
589
+ {
590
+ "name": "Blog Platform",
591
+ "request": {
592
+ "description": "Blog platform with content management",
593
+ "framework": "django",
594
+ "language": "python",
595
+ "requirements": [
596
+ "Article management",
597
+ "User comments and ratings",
598
+ "Category and tag system",
599
+ "SEO optimization",
600
+ "Media file handling"
601
+ ]
602
+ }
603
+ }
604
+ ]
605
+ }
606
+
607
+ if __name__ == "__main__":
608
+ import uvicorn
609
+ uvicorn.run(
610
+ "api_service:app",
611
+ host="0.0.0.0",
612
+ port=8000,
613
+ reload=True
614
+ )
data/.gitkeep ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+
evaluation/.gitkeep ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+
models/.gitkeep ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+
raw_dataset.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ datasets
4
+ pandas
5
+ numpy
6
+ aiohttp
7
+ requests
8
+ accelerate
9
+ fastapi
10
+ uvicorn
11
+ python-multipart
12
+
setup-guide.md ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Backend Code Generation Model - Setup & Usage Guide
2
+
3
+ ## 🛠️ Installation & Setup
4
+
5
+ ### 1. Install Dependencies
6
+ ```bash
7
+ pip install torch transformers datasets pandas numpy aiohttp requests
8
+ pip install accelerate # For faster training
9
+ ```
10
+
11
+ ### 2. Set Environment Variables
12
+ ```bash
13
+ # Optional: GitHub token for collecting real repositories
14
+ export GITHUB_TOKEN="your_github_token_here"
15
+
16
+ # For GPU training (if available)
17
+ export CUDA_VISIBLE_DEVICES=0
18
+ ```
19
+
20
+ ### 3. Directory Structure
21
+ ```
22
+ backend-ai-trainer/
23
+ ├── training_pipeline.py # Main pipeline code
24
+ ├── data/
25
+ │ ├── raw_dataset.json # Collected training data
26
+ │ └── processed/ # Preprocessed data
27
+ ├── models/
28
+ │ ├── backend_code_model/ # Trained model output
29
+ │ └── checkpoints/ # Training checkpoints
30
+ └── evaluation/
31
+ ├── test_cases.json # Test scenarios
32
+ └── results/ # Evaluation results
33
+ ```
34
+
35
+ ## 🏃‍♂️ Quick Start
36
+
37
+ ### Option A: Full Automated Pipeline
38
+ ```python
39
+ import asyncio
40
+ from training_pipeline import TrainingPipeline
41
+
42
+ config = {
43
+ 'base_model': 'microsoft/DialoGPT-medium',
44
+ 'output_dir': './models/backend_code_model',
45
+ 'github_token': 'your_token_here', # Optional
46
+ }
47
+
48
+ pipeline = TrainingPipeline(config)
49
+ asyncio.run(pipeline.run_full_pipeline())
50
+ ```
51
+
52
+ ### Option B: Step-by-Step Execution
53
+
54
+ #### Step 1: Collect Training Data
55
+ ```python
56
+ from training_pipeline import DataCollector
57
+ import asyncio
58
+
59
+ collector = DataCollector()
60
+
61
+ # Collect from GitHub (requires token)
62
+ github_queries = [
63
+ 'express api backend',
64
+ 'fastapi python backend',
65
+ 'django rest api',
66
+ 'nodejs backend server',
67
+ 'flask api backend'
68
+ ]
69
+
70
+ asyncio.run(collector.collect_github_repositories(github_queries, max_repos=100))
71
+
72
+ # Generate synthetic examples
73
+ collector.generate_synthetic_examples(count=500)
74
+
75
+ # Save dataset
76
+ collector.save_dataset('training_data.json')
77
+ ```
78
+
79
+ #### Step 2: Preprocess Data
80
+ ```python
81
+ from training_pipeline import DataPreprocessor
82
+
83
+ preprocessor = DataPreprocessor()
84
+ processed_examples = preprocessor.preprocess_examples(collector.collected_examples)
85
+ training_dataset = preprocessor.create_training_dataset(processed_examples)
86
+
87
+ print(f"Created dataset with {len(training_dataset)} examples")
88
+ ```
89
+
90
+ #### Step 3: Train Model
91
+ ```python
92
+ from training_pipeline import CodeGenerationModel
93
+
94
+ model = CodeGenerationModel('microsoft/DialoGPT-medium')
95
+ model.fine_tune(training_dataset, output_dir='./trained_model')
96
+ ```
97
+
98
+ #### Step 4: Generate Code
99
+ ```python
100
+ # Generate a complete backend application
101
+ generated_code = model.generate_code(
102
+ description="E-commerce API with user authentication and product management",
103
+ framework="fastapi",
104
+ language="python"
105
+ )
106
+
107
+ print("Generated Backend Application:")
108
+ print("=" * 50)
109
+ print(generated_code)
110
+ ```
111
+
112
+ ## 🎯 Training Configuration Options
113
+
114
+ ### Model Selection
115
+ ```python
116
+ # Lightweight for testing
117
+ config['base_model'] = 'microsoft/DialoGPT-small'
118
+
119
+ # Balanced performance
120
+ config['base_model'] = 'microsoft/DialoGPT-medium'
121
+
122
+ # High quality (requires more resources)
123
+ config['base_model'] = 'microsoft/DialoGPT-large'
124
+ ```
125
+
126
+ ### Training Parameters
127
+ ```python
128
+ training_config = {
129
+ 'num_epochs': 5, # More epochs = better learning
130
+ 'batch_size': 4, # Adjust based on GPU memory
131
+ 'learning_rate': 5e-5, # Conservative learning rate
132
+ 'max_length': 2048, # Maximum token length
133
+ 'warmup_steps': 500, # Learning rate warmup
134
+ 'save_steps': 1000, # Checkpoint frequency
135
+ }
136
+ ```
137
+
138
+ ### Framework Coverage
139
+ The pipeline supports these backend frameworks:
140
+
141
+ **Node.js Frameworks:**
142
+ - Express.js - Most popular Node.js framework
143
+ - NestJS - Enterprise-grade framework
144
+ - Koa.js - Lightweight alternative
145
+
146
+ **Python Frameworks:**
147
+ - FastAPI - Modern, high-performance API framework
148
+ - Django - Full-featured web framework
149
+ - Flask - Lightweight and flexible
150
+
151
+ **Go Frameworks:**
152
+ - Gin - HTTP web framework
153
+ - Fiber - Express-inspired framework
154
+
155
+ ## 📊 Evaluation & Testing
156
+
157
+ ### Automatic Quality Assessment
158
+ ```python
159
+ from training_pipeline import ModelEvaluator
160
+
161
+ evaluator = ModelEvaluator()
162
+
163
+ # Test specific code generation
164
+ generated_code = model.generate_code(
165
+ description="User authentication API with JWT tokens",
166
+ framework="express",
167
+ language="javascript"
168
+ )
169
+
170
+ # Get quality scores
171
+ quality_scores = evaluator.evaluate_code_quality(generated_code, "javascript")
172
+ print(f"Syntax Correctness: {quality_scores['syntax_correctness']:.2f}")
173
+ print(f"Completeness: {quality_scores['completeness']:.2f}")
174
+ print(f"Best Practices: {quality_scores['best_practices']:.2f}")
175
+ ```
176
+
177
+ ### Comprehensive Benchmarking
178
+ ```python
179
+ test_cases = [
180
+ {
181
+ 'description': 'REST API for task management with user authentication',
182
+ 'framework': 'express',
183
+ 'language': 'javascript'
184
+ },
185
+ {
186
+ 'description': 'GraphQL API for social media platform',
187
+ 'framework': 'fastapi',
188
+ 'language': 'python'
189
+ },
190
+ {
191
+ 'description': 'Microservice for payment processing',
192
+ 'framework': 'gin',
193
+ 'language': 'go'
194
+ }
195
+ ]
196
+
197
+ benchmark_results = evaluator.benchmark_model(model, test_cases)
198
+ print("Overall Performance:", benchmark_results)
199
+ ```
200
+
201
+ ## 🚀 Advanced Usage
202
+
203
+ ### Custom Data Sources
204
+ ```python
205
+ # Add your own training examples
206
+ custom_examples = [
207
+ {
208
+ 'description': 'Custom API requirement',
209
+ 'requirements': ['Custom feature 1', 'Custom feature 2'],
210
+ 'framework': 'fastapi',
211
+ 'language': 'python',
212
+ 'code_files': {
213
+ 'main.py': '# Your custom code here',
214
+ 'requirements.txt': 'fastapi\nuvicorn'
215
+ }
216
+ }
217
+ ]
218
+
219
+ # Add to training data
220
+ collector.collected_examples.extend([CodeExample(**ex) for ex in custom_examples])
221
+ ```
222
+
223
+ ### Fine-tuning on Specific Domains
224
+ ```python
225
+ # Focus training on specific application types
226
+ domain_specific_queries = [
227
+ 'microservices architecture',
228
+ 'api gateway implementation',
229
+ 'database orm integration',
230
+ 'authentication middleware',
231
+ 'rate limiting api'
232
+ ]
233
+
234
+ asyncio.run(collector.collect_github_repositories(domain_specific_queries))
235
+ ```
236
+
237
+ ### Export Trained Model
238
+ ```python
239
+ # Save model for deployment
240
+ model.model.save_pretrained('./production_model')
241
+ model.tokenizer.save_pretrained('./production_model')
242
+
243
+ # Load for inference
244
+ from transformers import AutoModelForCausalLM, AutoTokenizer
245
+
246
+ production_model = AutoModelForCausalLM.from_pretrained('./production_model')
247
+ production_tokenizer = AutoTokenizer.from_pretrained('./production_model')
248
+ ```
249
+
250
+ ## 🔧 Troubleshooting
251
+
252
+ ### Common Issues
253
+
254
+ **1. Out of Memory Errors**
255
+ ```python
256
+ # Reduce batch size
257
+ config['per_device_train_batch_size'] = 1
258
+ config['gradient_accumulation_steps'] = 4
259
+
260
+ # Use gradient checkpointing
261
+ config['gradient_checkpointing'] = True
262
+ ```
263
+
264
+ **2. Slow Training**
265
+ ```python
266
+ # Enable mixed precision (if GPU supports it)
267
+ config['fp16'] = True
268
+
269
+ # Use multiple GPUs
270
+ config['dataloader_num_workers'] = 4
271
+ ```
272
+
273
+ **3. Poor Code Quality**
274
+ ```python
275
+ # Increase training data diversity
276
+ collector.generate_synthetic_examples(count=1000)
277
+
278
+ # Extend training duration
279
+ config['num_train_epochs'] = 10
280
+ ```
281
+
282
+ ### Performance Optimization
283
+
284
+ **For CPU Training:**
285
+ ```python
286
+ config['dataloader_pin_memory'] = False
287
+ config['per_device_train_batch_size'] = 1
288
+ ```
289
+
290
+ **For GPU Training:**
291
+ ```python
292
+ config['fp16'] = True
293
+ config['dataloader_pin_memory'] = True
294
+ config['per_device_train_batch_size'] = 4
295
+ ```
296
+
297
+ ## 📈 Expected Results
298
+
299
+ After training on ~500-1000 examples, you should expect:
300
+
301
+ - **Syntax Correctness**: 85-95%
302
+ - **Code Completeness**: 80-90%
303
+ - **Best Practices**: 70-85%
304
+ - **Framework Coverage**: All major Node.js and Python frameworks
305
+ - **Generation Speed**: 2-5 seconds per application
306
+
307
+ ## 🔄 Continuous Improvement
308
+
309
+ ### Regular Retraining
310
+ ```python
311
+ # Schedule weekly data collection
312
+ import schedule
313
+
314
+ def update_training_data():
315
+ asyncio.run(collector.collect_github_repositories(['new backend trends']))
316
+
317
+ schedule.every().week.do(update_training_data)
318
+ ```
319
+
320
+ ### A/B Testing Different Models
321
+ ```python
322
+ models_to_compare = [
323
+ 'microsoft/DialoGPT-medium',
324
+ 'microsoft/DialoGPT-large',
325
+ 'gpt2-medium'
326
+ ]
327
+
328
+ for base_model in models_to_compare:
329
+ model = CodeGenerationModel(base_model)
330
+ results = evaluator.benchmark_model(model, test_cases)
331
+ print(f"{base_model}: {results}")
332
+ ```
333
+
334
+ ## 🎯 Next Steps
335
+
336
+ 1. **Start Small**: Begin with synthetic data and 100-200 examples
337
+ 2. **Add Real Data**: Integrate GitHub repositories gradually
338
+ 3. **Evaluate Regularly**: Monitor quality metrics after each training session
339
+ 4. **Expand Frameworks**: Add support for new frameworks as needed
340
+ 5. **Production Deploy**: Export model for API deployment
341
+
342
+ This pipeline provides a complete foundation for building your own backend code generation AI. The modular design allows you to customize and extend each component based on your specific needs.
test_api.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script for the Backend Code Generation API
4
+ ===============================================
5
+
6
+ Simple test script to verify the API is working correctly.
7
+ """
8
+
9
+ import requests
10
+ import json
11
+ import time
12
+ import os
13
+
14
+ # API base URL
15
+ BASE_URL = "http://localhost:8000"
16
+
17
+ def test_health():
18
+ """Test the health endpoint"""
19
+ print("Testing health endpoint...")
20
+ response = requests.get(f"{BASE_URL}/health")
21
+ print(f"Status: {response.status_code}")
22
+ print(f"Response: {response.json()}")
23
+ return response.status_code == 200
24
+
25
+ def test_generate_code():
26
+ """Test code generation"""
27
+ print("\nTesting code generation...")
28
+
29
+ # Test request
30
+ request_data = {
31
+ "description": "Simple REST API for task management",
32
+ "framework": "fastapi",
33
+ "language": "python",
34
+ "requirements": [
35
+ "User authentication",
36
+ "Task CRUD operations",
37
+ "Task status tracking"
38
+ ],
39
+ "project_name": "task-manager-api"
40
+ }
41
+
42
+ # Submit generation request
43
+ response = requests.post(f"{BASE_URL}/api/v1/generate", json=request_data)
44
+ print(f"Generation request status: {response.status_code}")
45
+
46
+ if response.status_code == 200:
47
+ result = response.json()
48
+ task_id = result["task_id"]
49
+ print(f"Task ID: {task_id}")
50
+
51
+ # Poll for completion
52
+ print("Polling for completion...")
53
+ for i in range(30): # Wait up to 5 minutes
54
+ status_response = requests.get(f"{BASE_URL}/api/v1/status/{task_id}")
55
+ if status_response.status_code == 200:
56
+ status = status_response.json()
57
+ print(f"Status: {status['status']} - {status['message']} ({status['progress']}%)")
58
+
59
+ if status["status"] == "completed":
60
+ print("✅ Generation completed!")
61
+ if status.get("download_url"):
62
+ print(f"Download URL: {status['download_url']}")
63
+ return True
64
+ elif status["status"] == "failed":
65
+ print(f"❌ Generation failed: {status.get('error', 'Unknown error')}")
66
+ return False
67
+ else:
68
+ print(f"Failed to get status: {status_response.status_code}")
69
+ return False
70
+
71
+ time.sleep(10) # Wait 10 seconds between polls
72
+
73
+ print("⏰ Timeout waiting for completion")
74
+ return False
75
+ else:
76
+ print(f"❌ Generation request failed: {response.text}")
77
+ return False
78
+
79
+ def test_frameworks():
80
+ """Test frameworks endpoint"""
81
+ print("\nTesting frameworks endpoint...")
82
+ response = requests.get(f"{BASE_URL}/api/v1/frameworks")
83
+ print(f"Status: {response.status_code}")
84
+ if response.status_code == 200:
85
+ frameworks = response.json()
86
+ print(f"Supported frameworks: {len(frameworks['frameworks'])}")
87
+ for fw in frameworks['frameworks']:
88
+ print(f" - {fw['name']} ({fw['language']})")
89
+ return True
90
+ return False
91
+
92
+ def test_examples():
93
+ """Test examples endpoint"""
94
+ print("\nTesting examples endpoint...")
95
+ response = requests.get(f"{BASE_URL}/api/v1/examples")
96
+ print(f"Status: {response.status_code}")
97
+ if response.status_code == 200:
98
+ examples = response.json()
99
+ print(f"Available examples: {len(examples['examples'])}")
100
+ for ex in examples['examples']:
101
+ print(f" - {ex['name']}")
102
+ return True
103
+ return False
104
+
105
+ def main():
106
+ """Run all tests"""
107
+ print("🚀 Testing Backend Code Generation API")
108
+ print("=" * 50)
109
+
110
+ # Check if API is running
111
+ try:
112
+ response = requests.get(f"{BASE_URL}/", timeout=5)
113
+ if response.status_code != 200:
114
+ print("❌ API is not running. Please start it with: python api_service.py")
115
+ return
116
+ except requests.exceptions.RequestException:
117
+ print("❌ Cannot connect to API. Please start it with: python api_service.py")
118
+ return
119
+
120
+ print("✅ API is running")
121
+
122
+ # Run tests
123
+ tests = [
124
+ ("Health Check", test_health),
125
+ ("Frameworks List", test_frameworks),
126
+ ("Examples List", test_examples),
127
+ ("Code Generation", test_generate_code),
128
+ ]
129
+
130
+ results = []
131
+ for test_name, test_func in tests:
132
+ print(f"\n{'='*20} {test_name} {'='*20}")
133
+ try:
134
+ result = test_func()
135
+ results.append((test_name, result))
136
+ except Exception as e:
137
+ print(f"❌ Test failed with error: {e}")
138
+ results.append((test_name, False))
139
+
140
+ # Summary
141
+ print(f"\n{'='*50}")
142
+ print("📊 Test Results Summary:")
143
+ print("=" * 50)
144
+
145
+ passed = 0
146
+ for test_name, result in results:
147
+ status = "✅ PASS" if result else "❌ FAIL"
148
+ print(f"{test_name}: {status}")
149
+ if result:
150
+ passed += 1
151
+
152
+ print(f"\nPassed: {passed}/{len(results)} tests")
153
+
154
+ if passed == len(results):
155
+ print("🎉 All tests passed!")
156
+ else:
157
+ print("⚠️ Some tests failed. Check the output above for details.")
158
+
159
+ if __name__ == "__main__":
160
+ main()
training_pipeline.py ADDED
@@ -0,0 +1,772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Backend Code Generation Model Training Pipeline
4
+ ===============================================
5
+
6
+ A comprehensive training pipeline for building an AI model that generates
7
+ framework-agnostic backend code with full application scaffolding.
8
+
9
+ Features:
10
+ - Data collection from multiple sources
11
+ - Multi-framework support (Express.js, FastAPI, Django, Flask, etc.)
12
+ - Full application scaffolding generation
13
+ - Model training with transformer architecture
14
+ - Evaluation and benchmarking tools
15
+ """
16
+
17
+ import os
18
+ import json
19
+ import logging
20
+ import asyncio
21
+ import aiohttp
22
+ import pandas as pd
23
+ import numpy as np
24
+ from typing import Dict, List, Optional, Tuple, Any
25
+ from dataclasses import dataclass, asdict
26
+ from pathlib import Path
27
+ import torch
28
+ import torch.nn as nn
29
+ from torch.utils.data import Dataset, DataLoader
30
+ from transformers import (
31
+ AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
32
+ Trainer, DataCollatorForLanguageModeling
33
+ )
34
+ from datasets import Dataset as HFDataset
35
+ import ast
36
+ import subprocess
37
+ import tempfile
38
+ from concurrent.futures import ThreadPoolExecutor
39
+ import requests
40
+ import time
41
+ import random
42
+
43
+ # Configure logging
44
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
45
+ logger = logging.getLogger(__name__)
46
+
47
+
48
+ @dataclass
49
+ class CodeExample:
50
+ """Represents a single training example"""
51
+ description: str
52
+ requirements: List[str]
53
+ framework: str
54
+ language: str
55
+ code_files: Dict[str, str] # filename -> content
56
+ project_structure: Dict[str, Any]
57
+ metadata: Dict[str, Any]
58
+
59
+
60
+ class DataCollector:
61
+ """Collects training data from various sources"""
62
+
63
+ def __init__(self):
64
+ self.github_token = os.getenv('GITHUB_TOKEN')
65
+ self.collected_examples: List[CodeExample] = []
66
+
67
+ async def collect_github_repositories(self, queries: List[str], max_repos: int = 100):
68
+ """Collect backend projects from GitHub"""
69
+ logger.info("Starting GitHub repository collection...")
70
+
71
+ headers = {'Authorization': f'token {self.github_token}'} if self.github_token else {}
72
+
73
+ async with aiohttp.ClientSession(headers=headers) as session:
74
+ per_query = max(1, max_repos // max(1, len(queries)))
75
+ for query in queries:
76
+ await self._search_github_repos(session, query, per_query)
77
+
78
+ async def _search_github_repos(self, session: aiohttp.ClientSession, query: str, limit: int):
79
+ """Search GitHub for repositories matching query"""
80
+ url = f"https://api.github.com/search/repositories"
81
+ params = {
82
+ 'q': query,
83
+ 'sort': 'stars',
84
+ 'order': 'desc',
85
+ 'per_page': min(limit, 100)
86
+ }
87
+
88
+ try:
89
+ async with session.get(url, params=params) as response:
90
+ if response.status == 200:
91
+ data = await response.json()
92
+ for repo in data.get('items', []):
93
+ await self._process_repository(session, repo)
94
+ else:
95
+ logger.warning(f"GitHub API request failed: {response.status}")
96
+ except Exception as e:
97
+ logger.error(f"Error searching GitHub: {e}")
98
+
99
+ async def _process_repository(self, session: aiohttp.ClientSession, repo: Dict):
100
+ """Process a single repository to extract code examples"""
101
+ logger.info(f"Processing repository: {repo.get('full_name', '<unknown>')}")
102
+
103
+ try:
104
+ contents_url = f"https://api.github.com/repos/{repo['full_name']}/contents"
105
+ async with session.get(contents_url) as response:
106
+ if response.status == 200:
107
+ contents = await response.json()
108
+ await self._extract_code_example(session, repo, contents)
109
+ except Exception as e:
110
+ logger.error(f"Error processing repository {repo.get('full_name')}: {e}")
111
+
112
+ async def _extract_code_example(self, session: aiohttp.ClientSession, repo: Dict, contents: List[Dict]):
113
+ """Extract a structured code example from repository"""
114
+ framework = self._identify_framework(contents, repo.get('description', ''))
115
+ language = self._identify_language(contents)
116
+
117
+ if not framework or not language:
118
+ return
119
+
120
+ code_files: Dict[str, str] = {}
121
+ for item in contents:
122
+ if item.get('type') == 'file' and self._is_important_file(item.get('name', '')):
123
+ try:
124
+ async with session.get(item['download_url']) as response:
125
+ if response.status == 200:
126
+ content = await response.text()
127
+ code_files[item['name']] = content
128
+ except Exception:
129
+ continue
130
+
131
+ if code_files:
132
+ example = CodeExample(
133
+ description=repo.get('description', ''),
134
+ requirements=self._extract_requirements(code_files),
135
+ framework=framework,
136
+ language=language,
137
+ code_files=code_files,
138
+ project_structure=self._analyze_structure(contents),
139
+ metadata={
140
+ 'stars': repo.get('stargazers_count', 0),
141
+ 'forks': repo.get('forks_count', 0),
142
+ 'url': repo.get('html_url'),
143
+ 'created_at': repo.get('created_at'),
144
+ 'updated_at': repo.get('updated_at')
145
+ }
146
+ )
147
+ self.collected_examples.append(example)
148
+
149
+ def _identify_framework(self, contents: List[Dict], description: str) -> Optional[str]:
150
+ """Identify the backend framework used"""
151
+ filenames = [item.get('name', '').lower() for item in contents if item.get('type') == 'file']
152
+
153
+ frameworks = {
154
+ 'express': ['package.json', 'app.js', 'server.js'],
155
+ 'fastapi': ['requirements.txt', 'main.py', 'app.py'],
156
+ 'django': ['manage.py', 'settings.py', 'requirements.txt'],
157
+ 'flask': ['app.py', 'requirements.txt'],
158
+ 'nestjs': ['nest-cli.json', 'package.json'],
159
+ 'koa': ['package.json'],
160
+ 'gin': ['go.mod', 'main.go'],
161
+ 'fiber': ['go.mod', 'main.go'],
162
+ }
163
+
164
+ for framework, required_files in frameworks.items():
165
+ if all(any(req in filename for filename in filenames) for req in required_files[:2]):
166
+ return framework
167
+
168
+ desc_lower = description.lower()
169
+ for framework in frameworks.keys():
170
+ if framework in desc_lower:
171
+ return framework
172
+
173
+ return None
174
+
175
+ def _identify_language(self, contents: List[Dict]) -> Optional[str]:
176
+ """Identify primary programming language"""
177
+ extensions: Dict[str, int] = {}
178
+ for item in contents:
179
+ if item.get('type') == 'file':
180
+ ext = Path(item.get('name', '')).suffix.lower()
181
+ if ext:
182
+ extensions[ext] = extensions.get(ext, 0) + 1
183
+
184
+ lang_map = {
185
+ '.js': 'javascript',
186
+ '.ts': 'typescript',
187
+ '.py': 'python',
188
+ '.go': 'go',
189
+ '.java': 'java',
190
+ '.cs': 'csharp',
191
+ '.rb': 'ruby',
192
+ '.php': 'php'
193
+ }
194
+
195
+ if extensions:
196
+ most_common_ext = max(extensions.items(), key=lambda x: x[1])[0]
197
+ return lang_map.get(most_common_ext)
198
+
199
+ return None
200
+
201
+ def _is_important_file(self, filename: str) -> bool:
202
+ """Check if file is important for training"""
203
+ important_patterns = [
204
+ 'package.json', 'requirements.txt', 'go.mod', 'pom.xml',
205
+ 'dockerfile', 'docker-compose.yml', 'readme.md',
206
+ 'app.py', 'main.py', 'server.js', 'app.js', 'index.js',
207
+ 'settings.py', 'config.py', 'routes.py', 'models.py',
208
+ 'controller.js', 'service.js', 'middleware.js'
209
+ ]
210
+
211
+ filename_lower = filename.lower()
212
+ return any(pattern in filename_lower for pattern in important_patterns)
213
+
214
+ def _extract_requirements(self, code_files: Dict[str, str]) -> List[str]:
215
+ """Extract functional requirements from code"""
216
+ requirements: List[str] = []
217
+
218
+ if 'package.json' in code_files:
219
+ try:
220
+ pkg_data = json.loads(code_files['package.json'])
221
+ deps = list(pkg_data.get('dependencies', {}).keys())
222
+ requirements.extend([f"Uses {dep}" for dep in deps[:5]])
223
+ except Exception:
224
+ pass
225
+
226
+ if 'requirements.txt' in code_files:
227
+ lines = code_files['requirements.txt'].strip().split('\n')
228
+ deps = [line.split('==')[0].split('>=')[0].strip() for line in lines if line.strip()]
229
+ requirements.extend([f"Uses {dep}" for dep in deps[:5]])
230
+
231
+ for filename, content in code_files.items():
232
+ if filename.endswith(('.js', '.py')):
233
+ endpoints = self._extract_endpoints(content)
234
+ requirements.extend(endpoints)
235
+
236
+ return requirements[:10]
237
+
238
+ def _extract_endpoints(self, code_content: str) -> List[str]:
239
+ """Extract API endpoints from code"""
240
+ endpoints: List[str] = []
241
+ lines = code_content.split('\n')
242
+
243
+ for line in lines:
244
+ s = line.strip()
245
+ if any(method in s for method in ['app.get(', 'app.post(', 'app.put(', 'app.delete(']):
246
+ endpoints.append(f"Implements {s}")
247
+ elif any(decorator in s for decorator in ['@app.get(', '@app.post(', '@app.put(', '@app.delete(']):
248
+ endpoints.append(f"Implements {s}")
249
+ elif 'def ' in s and any(word in s for word in ['get', 'post', 'put', 'delete']):
250
+ endpoints.append(f"Implements {s}")
251
+
252
+ return endpoints[:5]
253
+
254
+ def _analyze_structure(self, contents: List[Dict]) -> Dict[str, Any]:
255
+ """Analyze project structure"""
256
+ structure: Dict[str, Any] = {
257
+ 'files': [],
258
+ 'directories': [],
259
+ 'total_files': 0,
260
+ 'has_tests': False,
261
+ 'has_docs': False
262
+ }
263
+
264
+ for item in contents:
265
+ if item.get('type') == 'file':
266
+ name = item.get('name', '')
267
+ structure['files'].append(name)
268
+ structure['total_files'] += 1
269
+ if 'test' in name.lower():
270
+ structure['has_tests'] = True
271
+ if name.lower() in ['readme.md', 'docs.md']:
272
+ structure['has_docs'] = True
273
+ elif item.get('type') == 'dir':
274
+ structure['directories'].append(item.get('name', ''))
275
+
276
+ return structure
277
+
278
+ def generate_synthetic_examples(self, count: int = 100):
279
+ """Generate synthetic training examples"""
280
+ logger.info(f"Generating {count} synthetic examples...")
281
+
282
+ templates = [
283
+ {
284
+ 'description': 'REST API for user management',
285
+ 'requirements': ['User registration', 'User authentication', 'Profile management'],
286
+ 'frameworks': ['express', 'fastapi', 'django']
287
+ },
288
+ {
289
+ 'description': 'E-commerce backend API',
290
+ 'requirements': ['Product catalog', 'Shopping cart', 'Order processing', 'Payment integration'],
291
+ 'frameworks': ['nestjs', 'fastapi', 'django']
292
+ },
293
+ {
294
+ 'description': 'Task management system',
295
+ 'requirements': ['Task CRUD operations', 'User assignments', 'Status tracking'],
296
+ 'frameworks': ['express', 'flask', 'gin']
297
+ },
298
+ {
299
+ 'description': 'Blog platform backend',
300
+ 'requirements': ['Article management', 'User comments', 'Category system'],
301
+ 'frameworks': ['express', 'django', 'fastapi']
302
+ }
303
+ ]
304
+
305
+ for _ in range(count):
306
+ template = random.choice(templates)
307
+ framework = random.choice(template['frameworks'])
308
+
309
+ code_files = self._generate_code_for_template(template, framework)
310
+
311
+ example = CodeExample(
312
+ description=template['description'],
313
+ requirements=template['requirements'],
314
+ framework=framework,
315
+ language='python' if framework in ['fastapi', 'django', 'flask'] else 'javascript',
316
+ code_files=code_files,
317
+ project_structure=self._generate_synthetic_structure(framework),
318
+ metadata={'synthetic': True}
319
+ )
320
+
321
+ self.collected_examples.append(example)
322
+
323
+ def _generate_code_for_template(self, template: Dict, framework: str) -> Dict[str, str]:
324
+ """Generate code files for a template and framework"""
325
+ if framework == 'express':
326
+ return {
327
+ 'package.json': json.dumps({
328
+ "name": template['description'].lower().replace(' ', '-'),
329
+ "version": "1.0.0",
330
+ "dependencies": {
331
+ "express": "^4.18.0",
332
+ "mongoose": "^6.0.0",
333
+ "bcrypt": "^5.0.0",
334
+ "jsonwebtoken": "^8.5.0"
335
+ }
336
+ }, indent=2),
337
+ 'app.js': '''const express = require('express');
338
+ const mongoose = require('mongoose');
339
+ const app = express();
340
+
341
+ // Middleware
342
+ app.use(express.json());
343
+
344
+ // Routes
345
+ app.get('/health', (req, res) => {
346
+ res.json({ status: 'OK' });
347
+ });
348
+
349
+ // Start server
350
+ const PORT = process.env.PORT || 3000;
351
+ app.listen(PORT, () => {
352
+ console.log(`Server running on port ${PORT}`);
353
+ });
354
+
355
+ module.exports = app;'''
356
+ }
357
+ elif framework == 'fastapi':
358
+ return {
359
+ 'requirements.txt': '''fastapi==0.68.0
360
+ uvicorn==0.15.0
361
+ sqlalchemy==1.4.23
362
+ pydantic==1.8.2''',
363
+ 'main.py': '''from fastapi import FastAPI, HTTPException
364
+ from pydantic import BaseModel
365
+ from typing import List, Optional
366
+
367
+ app = FastAPI()
368
+
369
+ class Item(BaseModel):
370
+ id: Optional[int] = None
371
+ name: str
372
+ description: str
373
+
374
+ @app.get("/")
375
+ async def root():
376
+ return {"message": "Hello World"}
377
+
378
+ @app.get("/health")
379
+ async def health_check():
380
+ return {"status": "OK"}
381
+
382
+ if __name__ == "__main__":
383
+ import uvicorn
384
+ uvicorn.run(app, host="0.0.0.0", port=8000)'''
385
+ }
386
+ else:
387
+ return {'placeholder.txt': 'Generated code placeholder'}
388
+
389
+ def _generate_synthetic_structure(self, framework: str) -> Dict[str, Any]:
390
+ """Generate project structure for framework"""
391
+ if framework in ['express', 'nestjs']:
392
+ return {
393
+ 'files': ['package.json', 'app.js', 'README.md'],
394
+ 'directories': ['routes', 'controllers', 'middleware', 'models'],
395
+ 'total_files': 3,
396
+ 'has_tests': True,
397
+ 'has_docs': True
398
+ }
399
+ elif framework in ['fastapi', 'django', 'flask']:
400
+ return {
401
+ 'files': ['requirements.txt', 'main.py', 'README.md'],
402
+ 'directories': ['models', 'routes', 'services'],
403
+ 'total_files': 3,
404
+ 'has_tests': True,
405
+ 'has_docs': True
406
+ }
407
+ else:
408
+ return {}
409
+
410
+ def save_dataset(self, filepath: str):
411
+ """Save collected examples to file"""
412
+ data = [asdict(example) for example in self.collected_examples]
413
+ with open(filepath, 'w', encoding='utf-8') as f:
414
+ json.dump(data, f, indent=2, ensure_ascii=False)
415
+ logger.info(f"Saved {len(data)} examples to {filepath}")
416
+
417
+
418
+ class DataPreprocessor:
419
+ """Preprocesses collected data for training"""
420
+
421
+ def __init__(self, tokenizer_name: str = "microsoft/DialoGPT-medium"):
422
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
423
+ if self.tokenizer.pad_token is None:
424
+ self.tokenizer.pad_token = self.tokenizer.eos_token
425
+ # Ensure we do not exceed model's maximum positional embeddings (GPT-2/DialoGPT: 1024)
426
+ try:
427
+ model_max = getattr(self.tokenizer, 'model_max_length', 1024)
428
+ # Some tokenizers set a very large sentinel value; cap at 1024 for GPT-2 family
429
+ if model_max and model_max > 0 and model_max < 100000:
430
+ self.max_length = min(1024, int(model_max))
431
+ else:
432
+ self.max_length = 1024
433
+ except Exception:
434
+ self.max_length = 1024
435
+
436
+ def preprocess_examples(self, examples: List[CodeExample]) -> List[Dict[str, str]]:
437
+ """Convert examples to training format"""
438
+ processed: List[Dict[str, str]] = []
439
+
440
+ for example in examples:
441
+ input_text = self._create_input_text(example)
442
+ output_text = self._create_output_text(example)
443
+
444
+ processed.append({
445
+ 'input': input_text,
446
+ 'output': output_text,
447
+ 'framework': example.framework,
448
+ 'language': example.language
449
+ })
450
+
451
+ return processed
452
+
453
+ def _create_input_text(self, example: CodeExample) -> str:
454
+ """Create model input text"""
455
+ input_parts: List[str] = [
456
+ f"Description: {example.description}",
457
+ f"Framework: {example.framework}",
458
+ f"Language: {example.language}",
459
+ "Requirements:",
460
+ ]
461
+
462
+ for req in example.requirements:
463
+ input_parts.append(f"- {req}")
464
+
465
+ input_parts.append("Generate the backend application:")
466
+
467
+ return "\n".join(input_parts)
468
+
469
+ def _create_output_text(self, example: CodeExample) -> str:
470
+ """Create model output text"""
471
+ output_parts: List[str] = []
472
+
473
+ output_parts.append("Project Structure:")
474
+ for directory in example.project_structure.get('directories', []):
475
+ output_parts.append(f"/{directory}/")
476
+
477
+ output_parts.append("\nGenerated Files:")
478
+
479
+ for filename, content in example.code_files.items():
480
+ output_parts.append(f"\n--- {filename} ---")
481
+ output_parts.append(content)
482
+ output_parts.append("--- End ---")
483
+
484
+ return "\n".join(output_parts)
485
+
486
+ def create_training_dataset(self, processed_examples: List[Dict[str, str]]) -> HFDataset:
487
+ """Create Hugging Face dataset for training"""
488
+
489
+ def tokenize_function(examples: Dict[str, List[str]]):
490
+ texts: List[str] = []
491
+ for inp, out in zip(examples['input'], examples['output']):
492
+ text = f"<|startoftext|>{inp}<|separator|>{out}<|endoftext|>"
493
+ texts.append(text)
494
+
495
+ return self.tokenizer(
496
+ texts,
497
+ truncation=True,
498
+ padding=True,
499
+ max_length=self.max_length
500
+ )
501
+
502
+ dataset_dict = {
503
+ 'input': [ex['input'] for ex in processed_examples],
504
+ 'output': [ex['output'] for ex in processed_examples],
505
+ 'framework': [ex['framework'] for ex in processed_examples],
506
+ 'language': [ex['language'] for ex in processed_examples]
507
+ }
508
+
509
+ dataset = HFDataset.from_dict(dataset_dict)
510
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
511
+
512
+ return tokenized_dataset
513
+
514
+
515
+ class CodeGenerationModel:
516
+ """Custom model for backend code generation"""
517
+
518
+ def __init__(self, base_model: str = "microsoft/DialoGPT-medium"):
519
+ self.base_model = base_model
520
+ self.tokenizer = AutoTokenizer.from_pretrained(base_model)
521
+ self.model = AutoModelForCausalLM.from_pretrained(base_model)
522
+
523
+ if self.tokenizer.pad_token is None:
524
+ self.tokenizer.pad_token = self.tokenizer.eos_token
525
+
526
+ def fine_tune(self, dataset: HFDataset, output_dir: str = "./trained_model"):
527
+ """Fine-tune the model on backend code generation"""
528
+ logger.info("Starting model fine-tuning...")
529
+
530
+ training_args = TrainingArguments(
531
+ output_dir=output_dir,
532
+ overwrite_output_dir=True,
533
+ num_train_epochs=1, # Reduced from 3
534
+ per_device_train_batch_size=1, # Reduced from 2 for memory
535
+ per_device_eval_batch_size=1, # Reduced from 2
536
+ warmup_steps=50, # Reduced from 500
537
+ max_steps=100, # Drastically reduced from 2000
538
+ logging_steps=10, # More frequent logging
539
+ save_steps=50, # More frequent saves
540
+ save_total_limit=2,
541
+ prediction_loss_only=True,
542
+ fp16=torch.cuda.is_available(),
543
+ dataloader_pin_memory=False,
544
+ gradient_accumulation_steps=4, # Accumulate gradients for effective larger batch
545
+ learning_rate=5e-5, # Explicit learning rate
546
+ )
547
+
548
+ data_collator = DataCollatorForLanguageModeling(
549
+ tokenizer=self.tokenizer,
550
+ mlm=False,
551
+ )
552
+
553
+ train_size = int(0.8 * len(dataset))
554
+ eval_size = len(dataset) - train_size
555
+ train_dataset, eval_dataset = torch.utils.data.random_split(
556
+ dataset, [train_size, eval_size]
557
+ )
558
+
559
+ trainer = Trainer(
560
+ model=self.model,
561
+ args=training_args,
562
+ data_collator=data_collator,
563
+ train_dataset=train_dataset,
564
+ eval_dataset=eval_dataset,
565
+ )
566
+
567
+ trainer.train()
568
+ trainer.save_model()
569
+
570
+ logger.info("Fine-tuning completed!")
571
+
572
+ def generate_code(self, description: str, framework: str, language: str) -> str:
573
+ """Generate backend code for given requirements"""
574
+ input_text = (
575
+ f"Description: {description}\n"
576
+ f"Framework: {framework}\n"
577
+ f"Language: {language}\n"
578
+ f"Generate the backend application:"
579
+ )
580
+
581
+ # Respect model's max position embeddings (GPT-2/DialoGPT is typically 1024)
582
+ model_max_len = getattr(self.tokenizer, 'model_max_length', 1024)
583
+ max_len = 1024 if model_max_len is None or model_max_len > 100000 else min(1024, int(model_max_len))
584
+
585
+ inputs = self.tokenizer.encode(input_text, return_tensors='pt', truncation=True, max_length=max_len)
586
+
587
+ with torch.no_grad():
588
+ outputs = self.model.generate(
589
+ inputs,
590
+ max_length=max_len,
591
+ num_return_sequences=1,
592
+ temperature=0.7,
593
+ do_sample=True,
594
+ pad_token_id=self.tokenizer.eos_token_id
595
+ )
596
+
597
+ generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
598
+ return generated_text[len(input_text):]
599
+
600
+
601
+ class ModelEvaluator:
602
+ """Evaluates model performance"""
603
+
604
+ def __init__(self):
605
+ self.metrics: Dict[str, float] = {}
606
+
607
+ def evaluate_code_quality(self, generated_code: str, language: str) -> Dict[str, float]:
608
+ """Evaluate generated code quality"""
609
+ metrics = {
610
+ 'syntax_correctness': self._check_syntax(generated_code, language),
611
+ 'completeness': self._check_completeness(generated_code),
612
+ 'best_practices': self._check_best_practices(generated_code, language)
613
+ }
614
+
615
+ return metrics
616
+
617
+ def _check_syntax(self, code: str, language: str) -> float:
618
+ """Check if generated code has valid syntax"""
619
+ if language == 'python':
620
+ try:
621
+ ast.parse(code)
622
+ return 1.0
623
+ except SyntaxError:
624
+ return 0.0
625
+ elif language == 'javascript':
626
+ if '{' in code and '}' in code:
627
+ return 0.8
628
+ return 0.5
629
+
630
+ return 0.5
631
+
632
+ def _check_completeness(self, code: str) -> float:
633
+ """Check if code appears complete"""
634
+ completeness_indicators = [
635
+ 'import', 'require', 'function', 'def', 'class',
636
+ 'app.', 'router.', '@app.', 'app.listen', 'if __name__'
637
+ ]
638
+
639
+ indicators_found = sum(1 for indicator in completeness_indicators if indicator in code)
640
+ return min(indicators_found / 3.0, 1.0)
641
+
642
+ def _check_best_practices(self, code: str, language: str) -> float:
643
+ """Check adherence to best practices"""
644
+ best_practices_score = 0.0
645
+
646
+ if 'try:' in code or 'catch' in code:
647
+ best_practices_score += 0.2
648
+
649
+ if any(comment in code for comment in ['#', '//', '/*']):
650
+ best_practices_score += 0.2
651
+
652
+ if language == 'python':
653
+ if 'if __name__ == "__main__"' in code:
654
+ best_practices_score += 0.2
655
+ elif language == 'javascript':
656
+ if 'const' in code or 'let' in code:
657
+ best_practices_score += 0.2
658
+
659
+ return min(best_practices_score, 1.0)
660
+
661
+ def benchmark_model(self, model: 'CodeGenerationModel', test_cases: List[Dict]) -> Dict[str, float]:
662
+ """Benchmark model on test cases"""
663
+ total_scores = {'syntax': 0.0, 'completeness': 0.0, 'best_practices': 0.0}
664
+
665
+ for i, test_case in enumerate(test_cases):
666
+ generated_code = model.generate_code(
667
+ test_case['description'],
668
+ test_case['framework'],
669
+ test_case['language']
670
+ )
671
+
672
+ scores = self.evaluate_code_quality(generated_code, test_case['language'])
673
+
674
+ total_scores['syntax'] += scores['syntax_correctness']
675
+ total_scores['completeness'] += scores['completeness']
676
+ total_scores['best_practices'] += scores['best_practices']
677
+
678
+ logger.info(f"Test case {i+1}: {scores}")
679
+
680
+ num_cases = max(1, len(test_cases))
681
+ avg_scores = {key: value / num_cases for key, value in total_scores.items()}
682
+
683
+ return avg_scores
684
+
685
+
686
+ class TrainingPipeline:
687
+ """Main training pipeline orchestrator"""
688
+
689
+ def __init__(self, config: Dict[str, Any]):
690
+ self.config = config
691
+ self.data_collector = DataCollector()
692
+ self.preprocessor = DataPreprocessor(config.get('tokenizer', 'microsoft/DialoGPT-medium'))
693
+ self.model = CodeGenerationModel(config.get('base_model', 'microsoft/DialoGPT-medium'))
694
+ self.evaluator = ModelEvaluator()
695
+
696
+ async def run_full_pipeline(self):
697
+ """Run the complete training pipeline"""
698
+ logger.info("Starting full training pipeline...")
699
+
700
+ logger.info("Step 1: Collecting training data...")
701
+
702
+ if self.data_collector.github_token:
703
+ github_queries = [
704
+ 'express api backend',
705
+ 'fastapi python backend',
706
+ 'django rest api',
707
+ 'nodejs backend server',
708
+ 'flask api backend'
709
+ ]
710
+ await self.data_collector.collect_github_repositories(github_queries, max_repos=50)
711
+
712
+ self.data_collector.generate_synthetic_examples(count=200)
713
+
714
+ self.data_collector.save_dataset('raw_dataset.json')
715
+
716
+ logger.info("Step 2: Preprocessing data...")
717
+ processed_examples = self.preprocessor.preprocess_examples(self.data_collector.collected_examples)
718
+ training_dataset = self.preprocessor.create_training_dataset(processed_examples)
719
+
720
+ logger.info("Step 3: Training model...")
721
+ self.model.fine_tune(training_dataset, output_dir=self.config.get('output_dir', './trained_model'))
722
+
723
+ logger.info("Step 4: Evaluating model...")
724
+ test_cases = [
725
+ {
726
+ 'description': 'REST API for user management with authentication',
727
+ 'framework': 'express',
728
+ 'language': 'javascript'
729
+ },
730
+ {
731
+ 'description': 'FastAPI backend for e-commerce platform',
732
+ 'framework': 'fastapi',
733
+ 'language': 'python'
734
+ },
735
+ {
736
+ 'description': 'Django REST API for blog platform',
737
+ 'framework': 'django',
738
+ 'language': 'python'
739
+ }
740
+ ]
741
+
742
+ benchmark_results = self.evaluator.benchmark_model(self.model, test_cases)
743
+ logger.info(f"Benchmark results: {benchmark_results}")
744
+
745
+ logger.info("Training pipeline completed!")
746
+ return benchmark_results
747
+
748
+
749
+ if __name__ == "__main__":
750
+ config = {
751
+ 'base_model': 'microsoft/DialoGPT-medium',
752
+ 'tokenizer': 'microsoft/DialoGPT-medium',
753
+ 'output_dir': './backend_code_model',
754
+ 'github_token': os.getenv('GITHUB_TOKEN'),
755
+ }
756
+
757
+ pipeline = TrainingPipeline(config)
758
+
759
+ asyncio.run(pipeline.run_full_pipeline())
760
+
761
+ logger.info("\nTesting trained model...")
762
+ generated_code = pipeline.model.generate_code(
763
+ description="Create a REST API for managing tasks with CRUD operations",
764
+ framework="express",
765
+ language="javascript"
766
+ )
767
+
768
+ print("\nGenerated Code:")
769
+ print("=" * 50)
770
+ print(generated_code)
771
+
772
+