package api import ( "encoding/json" "log" "sync" "time" "git.rouggy.com/rouggy/ShackMaster/pkg/protocol" "github.com/gorilla/websocket" ) type Hub struct { clients map[*Client]bool broadcast chan *protocol.WebSocketMessage register chan *Client unregister chan *Client mu sync.RWMutex } type Client struct { hub *Hub conn *websocket.Conn send chan *protocol.WebSocketMessage } func NewHub() *Hub { return &Hub{ clients: make(map[*Client]bool), broadcast: make(chan *protocol.WebSocketMessage, 256), register: make(chan *Client), unregister: make(chan *Client), } } func (h *Hub) Run() { for { select { case client := <-h.register: h.mu.Lock() h.clients[client] = true h.mu.Unlock() log.Printf("Client connected, total: %d", len(h.clients)) case client := <-h.unregister: h.mu.Lock() if _, ok := h.clients[client]; ok { delete(h.clients, client) close(client.send) } h.mu.Unlock() log.Printf("Client disconnected, total: %d", len(h.clients)) case message := <-h.broadcast: h.mu.RLock() for client := range h.clients { select { case client.send <- message: default: // Client's send buffer is full, close it h.mu.RUnlock() h.unregister <- client h.mu.RLock() } } h.mu.RUnlock() } } } func (h *Hub) Broadcast(msg *protocol.WebSocketMessage) { h.broadcast <- msg } func (h *Hub) ClientCount() int { h.mu.RLock() defer h.mu.RUnlock() return len(h.clients) } const ( writeWait = 10 * time.Second pongWait = 60 * time.Second pingPeriod = (pongWait * 9) / 10 maxMessageSize = 512 * 1024 // 512KB ) func (c *Client) readPump() { defer func() { c.hub.unregister <- c c.conn.Close() }() c.conn.SetReadDeadline(time.Now().Add(pongWait)) c.conn.SetReadLimit(maxMessageSize) c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)) return nil }) for { var msg protocol.WebSocketMessage err := c.conn.ReadJSON(&msg) if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("WebSocket error: %v", err) } break } // Handle incoming commands from client log.Printf("Received message: type=%s, device=%s", msg.Type, msg.Device) // Commands are handled via REST API, not WebSocket // WebSocket is primarily for server -> client updates // Client should use REST endpoints for commands } } func (c *Client) writePump() { ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() c.conn.Close() }() for { select { case message, ok := <-c.send: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { // Hub closed the channel c.conn.WriteMessage(websocket.CloseMessage, []byte{}) return } if err := c.conn.WriteJSON(message); err != nil { log.Printf("Error writing message: %v", err) return } case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } } } } func ServeWs(hub *Hub, conn *websocket.Conn) { client := &Client{ hub: hub, conn: conn, send: make(chan *protocol.WebSocketMessage, 256), } client.hub.register <- client // Send initial status go func() { time.Sleep(100 * time.Millisecond) client.send <- &protocol.WebSocketMessage{ Type: protocol.MsgTypeStatus, Data: map[string]string{"status": "connected"}, Timestamp: time.Now(), } }() go client.writePump() go client.readPump() } // BroadcastStatusUpdate sends a status update to all connected clients func (h *Hub) BroadcastStatusUpdate(status interface{}) { data, err := json.Marshal(status) if err != nil { log.Printf("Error marshaling status: %v", err) return } var statusMap map[string]interface{} if err := json.Unmarshal(data, &statusMap); err != nil { log.Printf("Error unmarshaling status: %v", err) return } h.Broadcast(&protocol.WebSocketMessage{ Type: protocol.MsgTypeUpdate, Data: statusMap, Timestamp: time.Now(), }) }