package model

import (
	"crypto/rand"
	"encoding/json"
	"errors"
	"log"
	"strings"

	"github.com/tjfoc/gmsm/sm2"
	"nlt.com/pf/nltconst"
)

// Envelope types exchanged over HTTP.
//
// HttpBodyReq and HttpBodyResp are the plaintext envelopes: a head plus a
// typed body, serialized as {"head": ..., "body": ...}. The Crypt* variants
// wrap the same envelope shape with the body held as raw bytes, alongside a
// detached signature. The Crypt* types carry no json tags, so their keys
// serialize as "Request", "Response" and "Signature", and their []byte
// fields as base64 strings.
//
// Field order is part of the JSON output (and therefore of any signature
// computed over it), so do not reorder fields.
type (
	// CryptHttpBodyReq is a request envelope whose body is encrypted bytes,
	// paired with the signature that is checked before the body is used.
	CryptHttpBodyReq struct {
		Request   HttpBodyReq[[]byte]
		Signature []byte
	}

	// CryptHttpBodyResp is a response envelope whose body is encrypted bytes,
	// paired with the signature produced for the response.
	CryptHttpBodyResp struct {
		Response  HttpBodyResp[[]byte]
		Signature []byte
	}

	// HttpBodyReq is a plaintext request envelope carrying a body of type T.
	HttpBodyReq[T any] struct {
		Head    ReqHead `json:"head"`
		Request T       `json:"body"`
	}

	// HttpBodyResp is a plaintext response envelope carrying a body of type T.
	HttpBodyResp[T any] struct {
		Head     RespHead `json:"head"`
		Response T        `json:"body"`
	}

	// ReqHead is the metadata header of a request envelope.
	ReqHead struct {
		RequestTime string `json:"requestTime"`
		ServiceSn   string `json:"serviceSn"`
	}

	// RespHead is the metadata header of a response envelope.
	RespHead struct {
		Code        string `json:"code"`
		ServiceTime string `json:"serviceTime"`
		ServiceSn   string `json:"serviceSn"`
	}
)

// EncryptAndSign converts a plaintext response envelope into its encrypted,
// signed form.
//
// The JSON encoding of resp.Response is encrypted with the SM2 public key
// (C1C2C3 mode) and stored, together with the unchanged resp.Head, in the
// returned envelope. The signature is an SM2 signature (uid "tk") over the
// JSON encoding of the whole plaintext envelope resp (head and body). It is
// returned as the raw concatenation r||s, with each scalar left-padded to
// 32 bytes so the result is always 64 bytes long.
//
// Every error is logged and returned together with the zero
// CryptHttpBodyResp.
func EncryptAndSign[T any](resp HttpBodyResp[T]) (CryptHttpBodyResp, error) {
	// The key is obtained by feeding the configured constant to GenerateKey
	// as its entropy source. NOTE(review): presumably this yields the same
	// key on every call; confirm the peer derives its counterpart the same way.
	privateKey, err := sm2.GenerateKey(strings.NewReader(nltconst.SM2_PRIVATE_KEY))
	if err != nil {
		// Without this check a failed derivation leaves privateKey nil and
		// the dereferences below panic.
		log.Println(err)
		return CryptHttpBodyResp{}, err
	}

	body, err := json.Marshal(resp.Response)
	if err != nil {
		log.Println(err)
		return CryptHttpBodyResp{}, err
	}
	ciphertext, err := sm2.Encrypt(&privateKey.PublicKey, body, rand.Reader, sm2.C1C2C3)
	if err != nil {
		log.Println(err)
		return CryptHttpBodyResp{}, err
	}

	// NOTE(review): the signature covers the plaintext envelope, not the
	// ciphertext, so a receiver must decrypt before it can verify.
	signed, err := json.Marshal(resp)
	if err != nil {
		log.Println(err)
		return CryptHttpBodyResp{}, err
	}
	uid := []byte("tk")
	r, s, err := sm2.Sm2Sign(privateKey, signed, uid, rand.Reader)
	if err != nil {
		log.Println(err)
		return CryptHttpBodyResp{}, err
	}

	// r and s are below the 256-bit SM2 curve order, so each fits in 32
	// bytes. big.Int.Bytes() drops leading zero bytes, which would make r||s
	// ambiguous to split. FillBytes left-pads to a fixed width instead, and
	// produces the same bytes whenever both scalars are already 32 bytes long.
	const scalarLen = 32
	signature := make([]byte, 2*scalarLen)
	r.FillBytes(signature[:scalarLen])
	s.FillBytes(signature[scalarLen:])

	var cresp CryptHttpBodyResp
	cresp.Response.Head = resp.Head
	cresp.Response.Response = ciphertext
	cresp.Signature = signature
	return cresp, nil
}

// VerifyAndDecrypt decrypts an encrypted request envelope, checks its
// signature, and returns the typed plaintext request.
//
// Steps:
//   - creq.Signature is parsed with sm2.SignDataToSignDigit into (r, s).
//   - creq.Request.Request is decrypted with the SM2 private key (C1C2C3).
//   - The signature is verified (uid "tk") over the JSON encoding of the
//     plaintext envelope {"head": creq.Request.Head, "body": <decrypted bytes>}.
//     This is the same envelope shape EncryptAndSign signs for responses.
//   - Only after verification succeeds is the body unmarshaled into T.
//
// Errors:
//   - Key derivation and signature-parsing errors are returned unchanged.
//   - A decryption failure returns "解密错误".
//   - A failed or impossible verification returns "验签错误".
//   - A verified body that does not decode into T returns the json error.
//
// On any error the zero HttpBodyReq[T] is returned.
func VerifyAndDecrypt[T any](creq CryptHttpBodyReq) (HttpBodyReq[T], error) {
	// NOTE(review): same key derivation as EncryptAndSign; confirm it is the
	// intended way to load the fixed key.
	privateKey, err := sm2.GenerateKey(strings.NewReader(nltconst.SM2_PRIVATE_KEY))
	if err != nil {
		// Without this check a failed derivation leaves privateKey nil and
		// the dereferences below panic.
		log.Println(err)
		return HttpBodyReq[T]{}, err
	}

	r, s, err := sm2.SignDataToSignDigit(creq.Signature)
	if err != nil {
		log.Println(err)
		return HttpBodyReq[T]{}, err
	}

	plaintext, err := sm2.Decrypt(privateKey, creq.Request.Request, sm2.C1C2C3)
	if err != nil {
		log.Println(err)
		return HttpBodyReq[T]{}, errors.New("解密错误")
	}

	// A signature cannot cover itself, so the signed message must be data
	// the sender had. NOTE(review): this assumes the client signs exactly
	// what EncryptAndSign signs for responses, i.e. the JSON of the plaintext
	// envelope. Confirm against the client implementation.
	// json.RawMessage embeds the decrypted body verbatim, which avoids a
	// lossy round-trip through T that could change the bytes.
	signed, err := json.Marshal(HttpBodyReq[json.RawMessage]{
		Head:    creq.Request.Head,
		Request: json.RawMessage(plaintext),
	})
	if err != nil {
		// The decrypted body is not valid JSON, so it cannot match a signed
		// envelope.
		log.Println(err)
		return HttpBodyReq[T]{}, errors.New("验签错误")
	}
	uid := []byte("tk")
	if !sm2.Sm2Verify(&privateKey.PublicKey, signed, uid, r, s) {
		return HttpBodyReq[T]{}, errors.New("验签错误")
	}

	var req HttpBodyReq[T]
	req.Head = creq.Request.Head
	// Unmarshal must receive a pointer; passing req.Request by value always
	// fails with InvalidUnmarshalError.
	if err := json.Unmarshal(plaintext, &req.Request); err != nil {
		log.Println(err)
		return HttpBodyReq[T]{}, err
	}
	return req, nil
}
