Edwin Salguero
feat(ui): add robust multi-interface UI system (Streamlit, Dash, Jupyter, WebSocket) with launcher, docs, and integration tests [skip ci]
9f44dc9
| """ | |
| WebSocket Server for Real-time Trading Data | |
| Provides real-time updates for: | |
| - Market data streaming | |
| - Trading signals | |
| - Portfolio updates | |
| - System alerts | |
| """ | |
| import asyncio | |
| import websockets | |
| import json | |
| import logging | |
| import threading | |
| import time | |
| from datetime import datetime, timedelta | |
| from typing import Dict, Any, List, Optional | |
| import pandas as pd | |
| import os | |
| import sys | |
| # Add project root to path | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from agentic_ai_system.main import load_config | |
| from agentic_ai_system.data_ingestion import load_data, add_technical_indicators | |
| from agentic_ai_system.alpaca_broker import AlpacaBroker | |
| from agentic_ai_system.finrl_agent import FinRLAgent, FinRLConfig | |
class TradingWebSocketServer:
    """Real-time trading WebSocket server.

    Pushes JSON messages (each with a ``type`` field) to every connected
    client:
    - ``initial_data`` on connect;
    - periodic ``market_data`` and ``portfolio_update`` messages;
    - ``trading_signal`` messages while trading is active;
    - ``trading_status`` and ``training_completed`` notifications.

    Clients control the server by sending JSON objects whose ``type`` is one
    of ``load_config``, ``connect_alpaca``, ``start_trading``,
    ``stop_trading``, ``load_data`` or ``train_model``.
    """

    # Config entries that must never leave the server. ``connect_alpaca``
    # stores the client's Alpaca credentials in ``self.config``, and the config
    # is echoed to clients in ``initial_data`` and ``config_loaded`` messages.
    SENSITIVE_CONFIG_KEYS = frozenset({"api_key", "secret_key"})
    REDACTED_VALUE = "***"

    # Seconds a training thread waits for its result broadcast to complete.
    TRAINING_BROADCAST_TIMEOUT = 30

    def __init__(self, host="localhost", port=8765):
        """Create an idle server; nothing runs until ``start_server``/``run_server``.

        Args:
            host: Interface the WebSocket server binds to.
            port: TCP port the WebSocket server listens on.
        """
        self.host = host
        self.port = port
        self.clients = set()  # currently connected websocket objects
        self.config = None  # dict from load_config(), or None until loaded
        self.alpaca_broker = None
        # Holds only an agent whose training reported success (see
        # _run_training), so signals are never generated from a model that is
        # still being trained on another thread.
        self.finrl_agent = None
        self.trading_active = False
        # Result of load_data() + add_technical_indicators(); presumably a
        # pandas DataFrame (generate_trading_signal calls .tail()).
        self.market_data = None
        self.portfolio_data = {}  # last portfolio_update message broadcast
        # Strong references to the background loops: the event loop keeps only
        # weak references to tasks, so unreferenced tasks may be collected.
        self._background_tasks = []

        # Setup logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    @classmethod
    def _redact_config(cls, value):
        """Return a copy of *value* with sensitive config entries masked.

        Recurses through dicts and lists. Non-empty values stored under a key
        in ``SENSITIVE_CONFIG_KEYS`` are replaced by ``REDACTED_VALUE``.
        The input is never modified.
        """
        if isinstance(value, dict):
            redacted = {}
            for key, item in value.items():
                if key in cls.SENSITIVE_CONFIG_KEYS and item:
                    redacted[key] = cls.REDACTED_VALUE
                else:
                    redacted[key] = cls._redact_config(item)
            return redacted
        if isinstance(value, list):
            return [cls._redact_config(item) for item in value]
        return value

    async def register(self, websocket):
        """Add *websocket* to the client set and send it the initial snapshot."""
        self.clients.add(websocket)
        self.logger.info(f"Client connected. Total clients: {len(self.clients)}")
        # Send initial data
        await self.send_initial_data(websocket)

    async def unregister(self, websocket):
        """Remove *websocket* from the client set; a no-op if it is absent."""
        # discard(), not remove(): this runs in the handler's finally-block and
        # must not raise there.
        self.clients.discard(websocket)
        self.logger.info(f"Client disconnected. Total clients: {len(self.clients)}")

    async def send_initial_data(self, websocket):
        """Send the current config (secrets redacted), portfolio and trading status."""
        initial_data = {
            "type": "initial_data",
            "timestamp": datetime.now().isoformat(),
            "config": self._redact_config(self.config),
            "portfolio": self.portfolio_data,
            "trading_status": self.trading_active
        }
        await websocket.send(json.dumps(initial_data))

    async def broadcast(self, message):
        """JSON-serialise *message* and send it to every connected client.

        A failed send to one client does not stop delivery to the others.
        Failures are logged at debug level. Disconnected clients are removed
        by their own ``websocket_handler``.

        Raises:
            TypeError: if *message* is not JSON-serialisable.
        """
        if not self.clients:
            return
        message_str = json.dumps(message)
        # Snapshot the set so each result can be paired with its client even if
        # the set changes while the sends are pending.
        recipients = list(self.clients)
        results = await asyncio.gather(
            *(client.send(message_str) for client in recipients),
            return_exceptions=True
        )
        for client, result in zip(recipients, results):
            if isinstance(result, BaseException):
                self.logger.debug(f"Broadcast to {client!r} failed: {result!r}")

    def _broadcast_threadsafe(self, loop, message):
        """Schedule ``broadcast(message)`` on *loop* from a non-event-loop thread.

        Returns the ``concurrent.futures.Future`` produced by
        ``asyncio.run_coroutine_threadsafe``.
        """
        return asyncio.run_coroutine_threadsafe(self.broadcast(message), loop)

    async def handle_market_data(self):
        """Broadcast ``market_data`` for the configured symbol roughly every second.

        Does nothing until both a config and an Alpaca broker are present.
        Runs forever; errors are logged and retried after 5 seconds.
        """
        while True:
            try:
                if self.config and self.alpaca_broker:
                    # Get real-time market data
                    symbol = self.config['trading']['symbol']
                    # NOTE(review): price and volume each make their own
                    # synchronous get_latest_trade() call, so they may come from
                    # different trades, and the event loop is held while the
                    # broker calls run.
                    current_price = await self.get_current_price(symbol)
                    if current_price:
                        market_update = {
                            "type": "market_data",
                            "timestamp": datetime.now().isoformat(),
                            "symbol": symbol,
                            "price": current_price,
                            "volume": await self.get_current_volume(symbol)
                        }
                        await self.broadcast(market_update)
                        self.logger.info(f"Broadcasted market data for {symbol}: ${current_price}")
                await asyncio.sleep(1)  # Update every second
            except Exception as e:
                self.logger.error(f"Error in market data handler: {e}")
                await asyncio.sleep(5)  # Wait before retrying

    async def handle_portfolio_updates(self):
        """Broadcast ``portfolio_update`` roughly every 5 seconds once a broker is connected.

        The last update sent is cached in ``self.portfolio_data`` so new
        clients receive it in ``initial_data``. Runs forever; errors are
        logged and retried after 10 seconds.
        """
        while True:
            try:
                if self.alpaca_broker:
                    # NOTE(review): these broker calls are synchronous; if they
                    # do network I/O they block the event loop while running.
                    account_info = self.alpaca_broker.get_account_info()
                    positions = self.alpaca_broker.get_positions()
                    if account_info:
                        portfolio_update = {
                            "type": "portfolio_update",
                            "timestamp": datetime.now().isoformat(),
                            "account": {
                                "buying_power": float(account_info['buying_power']),
                                "portfolio_value": float(account_info['portfolio_value']),
                                "equity": float(account_info['equity']),
                                "cash": float(account_info['cash'])
                            },
                            "positions": positions if positions else []
                        }
                        await self.broadcast(portfolio_update)
                        self.portfolio_data = portfolio_update
                await asyncio.sleep(5)  # Update every 5 seconds
            except Exception as e:
                self.logger.error(f"Error in portfolio updates: {e}")
                await asyncio.sleep(10)  # Wait before retrying

    async def handle_trading_signals(self):
        """Broadcast ``trading_signal`` roughly every 10 seconds while trading is active.

        Requires a trained agent and loaded market data. Runs forever; errors
        are logged and retried after 30 seconds.
        """
        while True:
            try:
                if self.trading_active and self.finrl_agent and self.market_data is not None:
                    # Generate trading signals
                    signal = await self.generate_trading_signal()
                    if signal:
                        signal_update = {
                            "type": "trading_signal",
                            "timestamp": datetime.now().isoformat(),
                            "signal": signal
                        }
                        await self.broadcast(signal_update)
                        self.logger.info(f"Broadcasted trading signal: {signal}")
                await asyncio.sleep(10)  # Generate signals every 10 seconds
            except Exception as e:
                self.logger.error(f"Error in trading signals: {e}")
                await asyncio.sleep(30)  # Wait before retrying

    async def get_current_price(self, symbol):
        """Return the latest trade price for *symbol* as a float, or None.

        None is returned when there is no broker, no trade, or an error
        (which is logged).
        """
        try:
            if self.alpaca_broker:
                # Get latest price from Alpaca
                latest_trade = self.alpaca_broker.get_latest_trade(symbol)
                if latest_trade:
                    return float(latest_trade['p'])
            return None
        except Exception as e:
            self.logger.error(f"Error getting current price: {e}")
            return None

    async def get_current_volume(self, symbol):
        """Return the latest trade size for *symbol* as an int, or None.

        None is returned when there is no broker, no trade, or an error
        (which is logged).
        """
        try:
            if self.alpaca_broker:
                # Get latest trade volume
                latest_trade = self.alpaca_broker.get_latest_trade(symbol)
                if latest_trade:
                    return int(latest_trade['s'])
            return None
        except Exception as e:
            self.logger.error(f"Error getting current volume: {e}")
            return None

    async def generate_trading_signal(self):
        """Turn a FinRL prediction on the last 100 data rows into a trade signal.

        Returns:
            A dict with ``action`` (BUY/SELL/HOLD), ``confidence``, ``price``
            and ``reasoning``. Returns None when there is no agent or data,
            the prediction is unsuccessful, no price is available, or an error
            occurs (which is logged).

        A predicted ``total_return`` above +2% yields BUY and below -2% yields
        SELL, both with confidence ``min(0.9, 0.5 + |total_return|)``.
        Anything in between is HOLD with confidence 0.5.
        """
        try:
            if self.finrl_agent and self.market_data is not None:
                # Use recent data for prediction
                recent_data = self.market_data.tail(100)
                prediction_result = self.finrl_agent.predict(
                    data=recent_data,
                    config=self.config,
                    use_real_broker=False
                )
                if prediction_result['success']:
                    current_price = await self.get_current_price(self.config['trading']['symbol'])
                    if current_price:
                        signal = {
                            "action": "HOLD",  # Default action
                            "confidence": 0.5,
                            "price": current_price,
                            "reasoning": "Model prediction"
                        }
                        total_return = prediction_result['total_return']
                        if total_return > 0.02:  # 2% positive return
                            signal["action"] = "BUY"
                            signal["confidence"] = min(0.9, 0.5 + abs(total_return))
                        elif total_return < -0.02:  # 2% negative return
                            signal["action"] = "SELL"
                            signal["confidence"] = min(0.9, 0.5 + abs(total_return))
                        return signal
            return None
        except Exception as e:
            self.logger.error(f"Error generating trading signal: {e}")
            return None

    def _run_training(self, agent, data, config, training_steps, loop):
        """Train *agent* synchronously, then broadcast the outcome on *loop*.

        Meant to run on a worker thread, because training blocks for a long
        time. The agent becomes ``self.finrl_agent`` only if training reports
        success. Any failure is broadcast as
        ``{"type": "training_completed", "success": False, "error": ...}``.

        Args:
            agent: Object exposing ``train(data=, config=, total_timesteps=,
                use_real_broker=)`` returning a dict with a ``success`` key.
            data: Market data to train on.
            config: Configuration dict passed through to ``agent.train``.
            training_steps: Value passed as ``total_timesteps``.
            loop: The event loop the server runs on. asyncio.create_task
                cannot be used from this thread because it has no running loop.
        """
        try:
            result = agent.train(
                data=data,
                config=config,
                total_timesteps=training_steps,
                use_real_broker=False
            )
            message = {
                "type": "training_completed",
                "success": result['success'],
                "result": result
            }
            if result['success']:
                self.finrl_agent = agent
        except Exception as e:
            self.logger.exception("FinRL training failed")
            message = {
                "type": "training_completed",
                "success": False,
                "error": str(e)
            }
        try:
            # Wait for the broadcast to finish on the loop so serialisation or
            # scheduling errors are logged here instead of vanishing.
            self._broadcast_threadsafe(loop, message).result(
                timeout=self.TRAINING_BROADCAST_TIMEOUT
            )
        except Exception:
            self.logger.exception("Failed to broadcast training result")

    async def handle_client_message(self, websocket, message):
        """Parse one JSON control message from *websocket* and act on it.

        Each recognised message is answered on the same connection. Invalid
        JSON, non-object JSON, unknown types and unexpected errors are
        answered with ``{"type": "error", "message": ...}``.
        """
        try:
            data = json.loads(message)
            if not isinstance(data, dict):
                await websocket.send(json.dumps({
                    "type": "error",
                    "message": "Message must be a JSON object"
                }))
                return
            message_type = data.get("type")

            if message_type == "load_config":
                # NOTE(review): config_file comes straight from the client, so any
                # path readable by this process is handed to load_config().
                # Consider restricting it to an allow-list or config directory.
                config_file = data.get("config_file", "config.yaml")
                self.config = load_config(config_file)
                response = {
                    "type": "config_loaded",
                    "success": True,
                    "config": self._redact_config(self.config)
                }
                await websocket.send(json.dumps(response))

            elif message_type == "connect_alpaca":
                api_key = data.get("api_key")
                secret_key = data.get("secret_key")
                if not (api_key and secret_key):
                    response = {
                        "type": "alpaca_connected",
                        "success": False,
                        "error": "Missing API credentials"
                    }
                elif self.config is None:
                    # Credentials are stored in the config, so it must exist.
                    response = {
                        "type": "alpaca_connected",
                        "success": False,
                        "error": "Configuration not loaded"
                    }
                else:
                    self.config.setdefault('alpaca', {})['api_key'] = api_key
                    self.config['alpaca']['secret_key'] = secret_key
                    self.config.setdefault('execution', {})['broker_api'] = 'alpaca_paper'
                    self.alpaca_broker = AlpacaBroker(self.config)
                    response = {
                        "type": "alpaca_connected",
                        "success": True
                    }
                await websocket.send(json.dumps(response))

            elif message_type == "start_trading":
                self.trading_active = True
                response = {
                    "type": "trading_started",
                    "success": True
                }
                await websocket.send(json.dumps(response))
                # Broadcast to all clients
                await self.broadcast({
                    "type": "trading_status",
                    "active": True,
                    "timestamp": datetime.now().isoformat()
                })

            elif message_type == "stop_trading":
                self.trading_active = False
                response = {
                    "type": "trading_stopped",
                    "success": True
                }
                await websocket.send(json.dumps(response))
                # Broadcast to all clients
                await self.broadcast({
                    "type": "trading_status",
                    "active": False,
                    "timestamp": datetime.now().isoformat()
                })

            elif message_type == "load_data":
                if self.config:
                    self.market_data = load_data(self.config)
                    if self.market_data is not None:
                        self.market_data = add_technical_indicators(self.market_data)
                        response = {
                            "type": "data_loaded",
                            "success": True,
                            "data_points": len(self.market_data)
                        }
                    else:
                        response = {
                            "type": "data_loaded",
                            "success": False,
                            "error": "Failed to load data"
                        }
                else:
                    response = {
                        "type": "data_loaded",
                        "success": False,
                        "error": "Configuration not loaded"
                    }
                await websocket.send(json.dumps(response))

            elif message_type == "train_model":
                if self.market_data is not None:
                    algorithm = data.get("algorithm", "PPO")
                    learning_rate = data.get("learning_rate", 0.0003)
                    training_steps = data.get("training_steps", 100000)
                    finrl_config = FinRLConfig(
                        algorithm=algorithm,
                        learning_rate=learning_rate,
                        batch_size=64,
                        buffer_size=1000000,
                        learning_starts=100,
                        gamma=0.99,
                        tau=0.005,
                        train_freq=1,
                        gradient_steps=1,
                        verbose=1,
                        tensorboard_log='logs/finrl_tensorboard'
                    )
                    agent = FinRLAgent(finrl_config)
                    # Train on a daemon thread. Data and config are captured now
                    # so later load_data/load_config calls don't change this run.
                    # The running loop is handed over so the thread can schedule
                    # its result broadcast thread-safely.
                    training_thread = threading.Thread(
                        target=self._run_training,
                        args=(
                            agent,
                            self.market_data,
                            self.config,
                            training_steps,
                            asyncio.get_running_loop(),
                        ),
                        daemon=True
                    )
                    training_thread.start()
                    response = {
                        "type": "training_started",
                        "success": True
                    }
                else:
                    response = {
                        "type": "training_started",
                        "success": False,
                        "error": "Market data not loaded"
                    }
                await websocket.send(json.dumps(response))

            else:
                # Unknown message type
                response = {
                    "type": "error",
                    "message": f"Unknown message type: {message_type}"
                }
                await websocket.send(json.dumps(response))

        except json.JSONDecodeError:
            response = {
                "type": "error",
                "message": "Invalid JSON message"
            }
            await websocket.send(json.dumps(response))
        except Exception as e:
            self.logger.exception("Error handling client message")
            response = {
                "type": "error",
                "message": f"Server error: {str(e)}"
            }
            await websocket.send(json.dumps(response))

    async def websocket_handler(self, websocket, path=None):
        """Serve one client connection until it closes.

        Args:
            websocket: The client connection.
            path: Request path. Unused; accepted for websockets releases that
                still pass it. The default lets newer releases, which call the
                handler with the connection only, work as well.
        """
        try:
            # Registration sends the initial snapshot, which can itself fail
            # (e.g. the client already went away). Keeping it inside the try
            # guarantees the finally-block removes the client again.
            await self.register(websocket)
            async for message in websocket:
                await self.handle_client_message(websocket, message)
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            await self.unregister(websocket)

    async def start_server(self):
        """Start the background broadcast loops and serve until the server closes."""
        # Keep references so the loops cannot be garbage-collected mid-run.
        self._background_tasks = [
            asyncio.create_task(self.handle_market_data()),
            asyncio.create_task(self.handle_portfolio_updates()),
            asyncio.create_task(self.handle_trading_signals()),
        ]

        server = await websockets.serve(
            self.websocket_handler,
            self.host,
            self.port
        )
        self.logger.info(f"WebSocket server started on ws://{self.host}:{self.port}")

        # Keep server running
        await server.wait_closed()

    def run_server(self):
        """Run ``start_server`` in its own event loop on a daemon thread.

        Returns:
            The started ``threading.Thread``.
        """
        def run():
            asyncio.run(self.start_server())

        server_thread = threading.Thread(target=run)
        server_thread.daemon = True
        server_thread.start()
        return server_thread
def create_websocket_server(host="localhost", port=8765):
    """Build a TradingWebSocketServer for *host*:*port* without starting it.

    Call ``start_server()`` (async) or ``run_server()`` on the result to begin
    serving.
    """
    instance = TradingWebSocketServer(host=host, port=port)
    return instance