package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"flag"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"github.com/gorilla/websocket"
)

// domainUrlMap maps each value accepted by the -domain flag to the WebSocket
// endpoint that main dials when no -hostUrl override is given. Several domains
// share the same endpoint.
var domainUrlMap = map[string]string{
	"lite":        "wss://spark-api.xf-yun.com/v1.1/chat",
	"generalv3":   "wss://spark-api.xf-yun.com/v3.1/chat",
	"pro-128k":    "wss://spark-api.xf-yun.com/v3.1/chat",
	"generalv3.5": "wss://spark-api.xf-yun.com/v3.5/chat",
	"max-32k":     "wss://spark-api.xf-yun.com/v3.5/chat",
	"4.0Ultra":    "wss://spark-api.xf-yun.com/v4.0/chat",
	"kjwx":        "wss://spark-api.xf-yun.com/v1.1/chat",
}

func main() {
	var (
		appid     = flag.String("appid", "", "应用appid")
		apiSecret = flag.String("apiSecret", "", "API Secret")
		apiKey    = flag.String("apiKey", "", "API Key")
		question  = flag.String("q", "你是谁，可以干什么？", "问题")
		domain    = flag.String("domain", "lite", "模型版本")
		temp      = flag.Float64("temp", 0.5, "温度")
		maxTokens = flag.Int("maxTokens", 2048, "最大token")
		topK      = flag.Int("topK", 4, "topK")
		hostUrl   = flag.String("hostUrl", "", "自定义地址")
	)
	flag.Parse()

	if *appid == "" || *apiSecret == "" || *apiKey == "" {
		fmt.Println("错误: appid, apiSecret, apiKey 为必填参数")
		flag.PrintDefaults()
		os.Exit(1)
	}

	finalHostUrl := *hostUrl
	if finalHostUrl == "" {
		if url, ok := domainUrlMap[*domain]; ok {
			finalHostUrl = url
		} else {
			fmt.Printf("错误: 未知的domain '%s'\n", *domain)
			os.Exit(1)
		}
	}

	fmt.Printf("使用模型: %s\n", *domain)
	fmt.Printf("请求地址: %s\n", finalHostUrl)
	fmt.Printf("问题: %s\n\n", *question)

	d := websocket.Dialer{
		HandshakeTimeout: 5 * time.Second,
	}

	authUrl := assembleAuthUrl1(finalHostUrl, *apiKey, *apiSecret)
	conn, resp, err := d.Dial(authUrl, nil)
	if err != nil {
		panic(readResp(resp) + err.Error())
	} else if resp.StatusCode != 101 {
		panic(readResp(resp) + err.Error())
	}

	go func() {
		data := genParams1(*appid, *question, *domain, *temp, *maxTokens, *topK)
		conn.WriteJSON(data)
	}()

	var answer = ""
	for {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			fmt.Println("read message error:", err)
			break
		}

		var data map[string]interface{}
		err1 := json.Unmarshal(msg, &data)
		if err1 != nil {
			fmt.Println("Error parsing JSON:", err)
			return
		}

		payload := data["payload"].(map[string]interface{})
		choices := payload["choices"].(map[string]interface{})
		header := data["header"].(map[string]interface{})
		code := header["code"].(float64)

		if code != 0 {
			fmt.Println("错误:", data["payload"])
			return
		}

		status := choices["status"].(float64)
		text := choices["text"].([]interface{})
		content := text[0].(map[string]interface{})["content"].(string)

		if status != 2 {
			answer += content
			fmt.Print(content)
		} else {
			fmt.Println("\n\n收到最终结果")
			answer += content
			fmt.Print(content)

			usage := payload["usage"].(map[string]interface{})
			temp := usage["text"].(map[string]interface{})
			totalTokens := temp["total_tokens"].(float64)
			fmt.Printf("\n\ntotal_tokens: %.0f\n", totalTokens)
			conn.Close()
			break
		}
	}

	fmt.Println("\n\n完整回答:")
	fmt.Println(answer)
	time.Sleep(1 * time.Second)
}

// genParams1 builds the JSON-serialisable request body for a single-turn
// chat: a header carrying the app id, the chat parameters (domain,
// temperature, top_k, max_tokens, auditing fixed to "default"), and a payload
// whose message text holds one "user" message containing question.
func genParams1(appid, question, domain string, temperature float64, maxTokens, topK int) map[string]interface{} {
	header := map[string]interface{}{
		"app_id": appid,
	}

	chat := map[string]interface{}{
		"domain":      domain,
		"temperature": temperature,
		"top_k":       topK,
		"max_tokens":  maxTokens,
		"auditing":    "default",
	}
	parameter := map[string]interface{}{
		"chat": chat,
	}

	// The conversation consists of exactly one user turn.
	text := []Message{
		{Role: "user", Content: question},
	}
	payload := map[string]interface{}{
		"message": map[string]interface{}{
			"text": text,
		},
	}

	return map[string]interface{}{
		"header":    header,
		"parameter": parameter,
		"payload":   payload,
	}
}

// assembleAuthUrl1 returns hosturl with the query parameters "host", "date"
// and "authorization" appended.
//
// The signature is an HMAC-SHA256 (keyed with apiSecret, base64-encoded) over
// three lines: "host: <host>", "date: <date>" and "GET <path> HTTP/1.1".
// The authorization value is the base64 encoding of a string that names
// apiKey, the algorithm, the signed headers and that signature.
//
// If hosturl cannot be parsed, the error is printed and "" is returned so the
// caller's dial fails cleanly instead of this function dereferencing a nil URL.
func assembleAuthUrl1(hosturl string, apiKey, apiSecret string) string {
	ul, err := url.Parse(hosturl)
	if err != nil {
		// url.Parse returns a nil *url.URL on failure; using it would panic.
		fmt.Println(err)
		return ""
	}

	// Fixed layout with a literal "GMT" so the output does not depend on how
	// RFC1123 renders the zone on different systems.
	date := time.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT")

	// Fields covered by the signature: host, date, request-line.
	signString := []string{
		"host: " + ul.Host,
		"date: " + date,
		"GET " + ul.Path + " HTTP/1.1",
	}
	sgin := strings.Join(signString, "\n")

	// Signature over the joined string.
	sha := HmacWithShaTobase64("hmac-sha256", sgin, apiSecret)

	// Build the authorization parameter.
	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
		"hmac-sha256", "host date request-line", sha)

	// The authorization parameter is sent base64-encoded.
	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))

	v := url.Values{}
	v.Add("host", ul.Host)
	v.Add("date", date)
	v.Add("authorization", authorization)

	return hosturl + "?" + v.Encode()
}

// HmacWithShaTobase64 computes the HMAC-SHA256 of data keyed with key and
// returns the digest in standard (padded) base64. The algorithm argument is
// accepted but not consulted; SHA-256 is always used.
func HmacWithShaTobase64(algorithm, data, key string) string {
	signer := hmac.New(sha256.New, []byte(key))
	signer.Write([]byte(data))
	digest := signer.Sum(nil)

	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(digest)))
	base64.StdEncoding.Encode(encoded, digest)
	return string(encoded)
}

// readResp renders an HTTP response as "code=<status>,body=<body>" for use in
// diagnostics. It returns "" for a nil response and an empty body part when
// the response has no body.
//
// The body is always closed. A read failure is reported inline as a trailing
// ",readErr=<error>" rather than by panicking, so that it cannot mask the
// error the caller is in the middle of reporting.
func readResp(resp *http.Response) string {
	if resp == nil {
		return ""
	}
	if resp.Body == nil {
		return fmt.Sprintf("code=%d,body=", resp.StatusCode)
	}
	defer resp.Body.Close()

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		// Keep whatever was read before the failure; it may still be useful.
		return fmt.Sprintf("code=%d,body=%s,readErr=%v", resp.StatusCode, b, err)
	}
	return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, b)
}

// Message is one entry of the payload.message.text list that genParams1
// places in the request body.
type Message struct {
	// Role is serialised as "role"; genParams1 sets it to "user".
	Role string `json:"role"`
	// Content is serialised as "content" and carries the message text.
	Content string `json:"content"`
}
