package main

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"flag"
	"fmt"
	"io/fs"
	"log"
	"net/http"
	"os"
	"path"
	"runtime"
	"strings"
	"sync"
	"syscall"
	"unsafe"
)

var (
	// serve is the address handed to http.ListenAndServe in main
	// (host:port; the default ":11316" has an empty host).
	serve    = flag.String("serve", ":11316", "Port number.")
	// cacheDir is the directory for cached renders. An empty value disables
	// caching (makeNoise then calls sayUncached directly); otherwise main
	// creates it at startup if it does not exist.
	cacheDir = flag.String("cache_dir", "c:\\temp\\sapid-cache", "Cache dir.")

	// sapi4.dll and its MakeSamSay export, resolved lazily via syscall's
	// LazyDLL/LazyProc; say calls the proc through syscall.SyscallN.
	sapi           = syscall.NewLazyDLL("sapi4.dll")
	procMakeSamSay = sapi.NewProc("MakeSamSay")

	// NOTE(review): sayitMutex is not referenced anywhere in this file —
	// calls into the DLL are serialized by the single sayRoutine worker
	// instead. Confirm no other file in the package uses it before removing.
	sayitMutex sync.Mutex
)

// sayResponse is the reply sayRoutine sends for a single sayRequest; it
// carries the two return values of say.
type sayResponse struct {
	// Bytes is the audio file contents returned by say.
	Bytes []byte
	// Err is the error returned by say; when it is non-nil, sayUncached
	// ignores Bytes and returns the error.
	Err   error
}

// sayRequest asks the sayRoutine worker to render Text and deliver the
// result on ResponseCh.
type sayRequest struct {
	// Text is passed unchanged to say.
	Text       string
	// ResponseCh receives one sayResponse per request from sayRoutine.
	// sayUncached allocates it with capacity 1, so that send cannot block.
	ResponseCh chan sayResponse
}

// sayRoutine is the worker that serves render requests from reqCh strictly
// one at a time, replying to each on its ResponseCh.
//
// It locks itself to its current OS thread and never unlocks, so every call
// to say (and therefore into the DLL) runs on that same thread. It returns
// once reqCh is closed and drained.
func sayRoutine(reqCh chan sayRequest) {
	runtime.LockOSThread()
	for {
		req, open := <-reqCh
		if !open {
			return
		}
		audio, err := say(req.Text)
		req.ResponseCh <- sayResponse{Bytes: audio, Err: err}
	}
}

// sayChan is the queue of pending render requests. main creates it
// (unbuffered) before serving and starts one sayRoutine to consume it.
var sayChan chan sayRequest

// say renders str to audio by calling MakeSamSay from sapi4.dll and returns
// the contents of the file it produced.
//
// MakeSamSay is passed a NUL-terminated copy of str and a pointer to a
// pointer to a 17-byte buffer. A return value of 1 means success. The first
// 16 bytes of the buffer, up to any NUL and with surrounding whitespace
// trimmed, are then taken as the output file name inside a fresh temporary
// directory.
//
// The process working directory is switched to that temporary directory for
// the duration of the call. Presumably the DLL writes its output relative to
// the CWD (NOTE(review): confirm). The previous directory is restored before
// the temporary directory is removed.
//
// Because the CWD is process-wide, say must not run concurrently with
// itself. sayRoutine is its only caller and serializes calls.
func say(str string) ([]byte, error) {
	td, err := os.MkdirTemp("", "sayit-sapi")
	if err != nil {
		return nil, err
	}
	// Registered first so it runs last, after the CWD has been restored:
	// Windows will not delete a directory that is a process's CWD.
	defer func() {
		if err := os.RemoveAll(td); err != nil {
			log.Printf("removing temp dir %v: %v", td, err)
		}
	}()

	prevWD, err := os.Getwd()
	if err != nil {
		return nil, fmt.Errorf("getwd: %w", err)
	}
	if err := os.Chdir(td); err != nil {
		return nil, fmt.Errorf("chdir %v: %w", td, err)
	}
	defer func() {
		if err := os.Chdir(prevWD); err != nil {
			log.Printf("restoring working dir %v: %v", prevWD, err)
		}
	}()

	// NUL-terminated copy of the input for the C side.
	inp := make([]byte, len(str)+1)
	copy(inp, str)
	inpP := &inp[0]

	outfile := make([]byte, 17)
	outfileP := &outfile[0]
	outfilePP := &outfileP
	r1, _, _ := syscall.SyscallN(procMakeSamSay.Addr(), uintptr(unsafe.Pointer(inpP)), uintptr(unsafe.Pointer(outfilePP)))
	if r1 != 1 {
		return nil, fmt.Errorf("MakeSamSay failed (returned %d)", r1)
	}

	// The name may be NUL-terminated before byte 16, and TrimSpace does not
	// strip NULs, so cut at the first NUL before trimming.
	name := string(outfile[:16])
	if i := strings.IndexByte(name, 0); i >= 0 {
		name = name[:i]
	}
	name = strings.TrimSpace(name)
	if name == "" {
		return nil, errors.New("MakeSamSay returned an empty output file name")
	}

	wavPath := path.Join(td, name)
	log.Printf("input %q: reading from %v", str, wavPath)
	return os.ReadFile(wavPath)
}

// sayUncached renders str by handing it to the sayRoutine worker over
// sayChan and waiting for the reply, without consulting the on-disk cache.
//
// If ctx is done before the worker accepts the request, or before it replies,
// sayUncached returns ctx.Err() instead of blocking. The reply channel has
// capacity 1, so an abandoned request never blocks the worker. Once accepted,
// the render itself still runs to completion.
func sayUncached(ctx context.Context, str string) ([]byte, error) {
	respCh := make(chan sayResponse, 1)

	select {
	case sayChan <- sayRequest{Text: str, ResponseCh: respCh}:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	select {
	case resp := <-respCh:
		if resp.Err != nil {
			return nil, resp.Err
		}
		return resp.Bytes, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// sayCached renders str, with surrounding whitespace trimmed, using an
// on-disk cache in *cacheDir before falling back to sayUncached.
//
// Entries are keyed by the hex SHA-256 of the trimmed text:
//   - <hash>.wav holds the rendered audio.
//   - <hash>.txt holds the text itself. It is compared on lookup to rule out
//     collisions, and it marks the entry as complete.
//
// If the key cannot be read for any reason other than not existing, an
// error is returned. On a miss, or when a matching key's .wav cannot be
// read, the text is re-rendered and the cache refreshed. Cache write
// failures are logged and do not fail the request.
func sayCached(ctx context.Context, str string) ([]byte, error) {
	str = strings.TrimSpace(str)
	hb := sha256.Sum256([]byte(str))
	hbHex := hex.EncodeToString(hb[:])
	keyPath := path.Join(*cacheDir, hbHex+".txt")
	wavPath := path.Join(*cacheDir, hbHex+".wav")

	cachedStr, err := os.ReadFile(keyPath)
	switch {
	case err == nil:
		if strings.TrimSpace(string(cachedStr)) == str {
			wav, err := os.ReadFile(wavPath)
			if err == nil {
				return wav, nil
			}
			// A key without readable audio (e.g. an interrupted earlier
			// write) must not fail this text forever; re-render instead.
			log.Printf("reading cached audio %v: %v; re-rendering", wavPath, err)
		}
	case !errors.Is(err, fs.ErrNotExist):
		return nil, fmt.Errorf("checking cache %v/%v: %w", *cacheDir, hbHex, err)
	}

	out, err := sayUncached(ctx, str)
	if err != nil {
		return nil, err
	}

	// Write the audio before the key: the .txt marks an entry valid, so it is
	// only written once the .wav is complete.
	if err := os.WriteFile(wavPath, out, 0644); err != nil {
		log.Printf("failed writing cache data to %v/%v.wav: %v", *cacheDir, hbHex, err)
		// Drop any existing key so a partially written .wav is never served.
		if err := os.Remove(keyPath); err != nil && !errors.Is(err, fs.ErrNotExist) {
			log.Printf("removing stale cache key %v: %v", keyPath, err)
		}
		return out, nil
	}
	if err := os.WriteFile(keyPath, []byte(str), 0644); err != nil {
		log.Printf("failed writing cache key to %v/%v.txt: %v", *cacheDir, hbHex, err)
	}
	return out, nil
}

func makeNoise(rw http.ResponseWriter, r *http.Request) {
	text := r.FormValue("text")
	if text == "" {
		http.Error(rw, "bad request", http.StatusBadRequest)
		return
	}

	sayFunc := sayCached
	if *cacheDir == "" {
		sayFunc = sayUncached
	}
	respBytes, err := sayFunc(r.Context(), text)
	if err != nil {
		log.Printf("rendering %q: %v", text, err)
		http.Error(rw, "failed rendering", http.StatusInternalServerError)
		return
	}

	rw.Header().Set("Content-type", "audio/x-wav")
	rw.Write(respBytes)
}

// main parses flags, creates the cache directory when caching is enabled,
// starts the single sayRoutine worker, and serves the /sam endpoint on the
// -serve address. It exits via log.Fatal if startup or serving fails.
func main() {
	flag.Parse()

	if dir := *cacheDir; dir != "" {
		if err := os.MkdirAll(dir, 0700); err != nil {
			log.Fatalf("creating cache dir %v: %v", dir, err)
		}
	}

	// One worker, fed by an unbuffered channel, handles every render.
	reqs := make(chan sayRequest)
	sayChan = reqs
	go sayRoutine(reqs)

	http.HandleFunc("/sam", makeNoise)
	addr := *serve
	log.Printf("Listening on %s", addr)
	log.Fatal(http.ListenAndServe(addr, nil))
}