// SPDX-FileCopyrightText: 2021 Luke Granger-Brown <depot@lukegb.com>
//
// SPDX-License-Identifier: Apache-2.0

package hashfs

import (
	"crypto/sha512"
	"encoding/hex"
	"errors"
	"fmt"
	"io/fs"
	"path"
)

// adaptError rewrites an error returned by the underlying filesystem so that
// it refers to name, the path the caller of this package asked for, rather
// than the path that was actually used on the underlying filesystem.
//
// If err is, or wraps, an *fs.PathError, a new *fs.PathError is returned with
// the same Op and Err but with Path set to name; any wrapping around that
// *fs.PathError is not carried over. The original error is left unmodified:
// it was produced by the underlying filesystem, which might share or reuse
// it, so mutating it in place is unsafe. Any other error (including nil) is
// returned unchanged.
func adaptError(err error, name string) error {
	var pErr *fs.PathError
	if !errors.As(err, &pErr) {
		return err
	}
	return &fs.PathError{Op: pErr.Op, Path: name, Err: pErr.Err}
}

// FS wraps an fs.ReadFileFS so that each file can also be opened under a name
// that embeds a short hex digest of its contents (for example "app.js" is
// also reachable as "app.<digest>.js"). Create one with New.
type FS struct {
	// fs is the wrapped filesystem that holds the actual file contents.
	fs fs.ReadFileFS

	// Two-way name index filled in by build:
	//   toHashName:  original path -> content-hashed path
	//   toPlainName: content-hashed path -> original path
	toHashName, toPlainName map[string]string
}

// staticFSFileInfo decorates an fs.FileInfo so that Name reports the last
// element of newName. Every other FileInfo method is promoted unchanged from
// the embedded value.
type staticFSFileInfo struct {
	fs.FileInfo
	newName string // path whose final element Name reports
}

// Name returns the final slash-separated element of newName.
func (fi staticFSFileInfo) Name() string {
	base := path.Base(fi.newName)
	return base
}

// staticFSDirEntry presents an underlying fs.DirEntry under a different name.
// Name reports the last element of newName, and so does the FileInfo returned
// by Info; IsDir and Type are promoted from the embedded DirEntry.
type staticFSDirEntry struct {
	fs.DirEntry
	newName string
}

// Name returns the final slash-separated element of newName.
func (dent staticFSDirEntry) Name() string {
	return path.Base(dent.newName)
}

// Info fetches the underlying entry's FileInfo and wraps it so that its Name
// is derived from newName as well. Errors from the underlying Info are
// returned as-is, with a nil FileInfo.
func (dent staticFSDirEntry) Info() (fs.FileInfo, error) {
	info, err := dent.DirEntry.Info()
	if err == nil {
		return staticFSFileInfo{FileInfo: info, newName: dent.newName}, nil
	}
	return nil, err
}

// staticFSFile wraps a file opened from the underlying filesystem so that
// Stat reports the last element of newName (the name the file was opened
// under) instead of the underlying file's own name.
//
// Because fs.File is embedded as an interface, only the fs.File methods are
// promoted; any extra capabilities of the concrete underlying file would be
// hidden. Seek and ReadAt are therefore forwarded explicitly. net/http's FS
// adapter type-asserts io.Seeker on opened files, and http.FileServer needs
// Seek to serve content, so without this wrapper methods serving would fail.
type staticFSFile struct {
	fs.File
	newName string
}

// Stat returns the underlying file's info with Name derived from newName.
func (f *staticFSFile) Stat() (fs.FileInfo, error) {
	fi, err := f.File.Stat()
	if err != nil {
		return nil, err
	}
	return staticFSFileInfo{fi, f.newName}, nil
}

// Seek implements io.Seeker by delegating to the underlying file. If the
// underlying file does not implement Seek, it returns an *fs.PathError.
func (f *staticFSFile) Seek(offset int64, whence int) (int64, error) {
	// Same method set as io.Seeker.
	s, ok := f.File.(interface {
		Seek(offset int64, whence int) (int64, error)
	})
	if !ok {
		return 0, &fs.PathError{Op: "seek", Path: f.newName, Err: errors.New("underlying file does not implement Seek")}
	}
	return s.Seek(offset, whence)
}

// ReadAt implements io.ReaderAt by delegating to the underlying file. If the
// underlying file does not implement ReadAt, it returns an *fs.PathError.
func (f *staticFSFile) ReadAt(p []byte, off int64) (int, error) {
	// Same method set as io.ReaderAt.
	r, ok := f.File.(interface {
		ReadAt(p []byte, off int64) (int, error)
	})
	if !ok {
		return 0, &fs.PathError{Op: "read", Path: f.newName, Err: errors.New("underlying file does not implement ReadAt")}
	}
	return r.ReadAt(p, off)
}

// staticFSDirFile wraps a directory opened from the underlying filesystem so
// that ReadDir reports files under their content-hashed names.
type staticFSDirFile struct {
	fs.ReadDirFile
	// basePath is the directory's path as given to Open; entry names are
	// joined onto it for lookup in s.toHashName, the index s supplies.
	basePath string
	s        *FS
}

// ReadDir reads entries from the underlying directory and renames every entry
// that has a hashed name in f.s. Entries without one (directories are never
// indexed) are returned unchanged.
//
// The fs.ReadDirFile contract allows the underlying ReadDir to return a
// partial list together with a non-nil error (e.g. when n <= 0 and the read
// fails midway). Those entries are renamed and passed through with the error
// rather than discarded.
func (f staticFSDirFile) ReadDir(n int) ([]fs.DirEntry, error) {
	dents, err := f.ReadDirFile.ReadDir(n)
	for i, dent := range dents {
		fpath := path.Join(f.basePath, dent.Name())
		if hashName, ok := f.s.toHashName[fpath]; ok {
			dents[i] = staticFSDirEntry{dent, hashName}
		}
	}
	return dents, err
}

// build walks the wrapped filesystem and gives every entry that is not a
// directory a name embedding a digest of its contents. The digest is the
// first 6 bytes of the SHA-512 of the contents, hex-encoded (12 characters),
// inserted before the extension: "js/app.js" becomes "js/app.<digest>.js" and
// "LICENSE" becomes "LICENSE.<digest>".
//
// s.toHashName and s.toPlainName are replaced only if the whole walk
// succeeds; on error they are left as they were.
func (s *FS) build() error {
	hashedOf := make(map[string]string)
	plainOf := make(map[string]string)

	visit := func(p string, entry fs.DirEntry, walkErr error) error {
		switch {
		case walkErr != nil:
			return walkErr
		case entry.IsDir():
			// Directories keep their names; only files are renamed.
			return nil
		}

		contents, err := s.fs.ReadFile(p)
		if err != nil {
			return fmt.Errorf("reading %q: %w", p, err)
		}
		sum := sha512.Sum512(contents)
		digest := hex.EncodeToString(sum[:6])

		// With no extension, ext is "" and stem is the whole path, which
		// yields "<path>.<digest>".
		ext := path.Ext(p)
		stem := p[:len(p)-len(ext)]
		hashed := fmt.Sprintf("%s.%s%s", stem, digest, ext)

		hashedOf[p] = hashed
		plainOf[hashed] = p
		return nil
	}

	if err := fs.WalkDir(s.fs, ".", visit); err != nil {
		return err
	}
	s.toHashName, s.toPlainName = hashedOf, plainOf
	return nil
}

// LookupHashedName returns the content-hashed name recorded for the original
// path name. It reports false, with an empty string, when name was not a
// non-directory entry seen when the FS was built.
func (s *FS) LookupHashedName(name string) (string, bool) {
	if hashed, found := s.toHashName[name]; found {
		return hashed, true
	}
	return "", false
}

// Open opens the named file. name may be a content-hashed name or a path as
// it appears in the underlying filesystem; the latter also covers
// directories.
//
// Files are wrapped so that Stat reports path.Base(name). Directories are
// wrapped so that ReadDir lists indexed files under their hashed names.
// Errors carrying an *fs.PathError refer to name rather than to the
// underlying path.
func (s *FS) Open(name string) (fs.File, error) {
	fn, ok := s.toPlainName[name]
	if !ok {
		// Not a hashed name; try opening it as a plain name.
		fn = name
	}
	f, err := s.fs.Open(fn)
	if err != nil {
		return nil, adaptError(err, name)
	}
	// build records every non-directory entry in toHashName, so a hit means
	// fn is a file. Check it before asserting fs.ReadDirFile: os.DirFS
	// returns *os.File, which has a ReadDir method even for regular files.
	// Wrapping those as a directory would make Stat report the unhashed name.
	if _, isFile := s.toHashName[fn]; !isFile {
		if rdf, ok := f.(fs.ReadDirFile); ok {
			return &staticFSDirFile{rdf, name, s}, nil
		}
	}
	return &staticFSFile{f, name}, nil
}

// ReadFile returns the contents of the named file, given either by its
// content-hashed name or by its original path. It delegates straight to the
// wrapped filesystem's ReadFile instead of going through Open. Errors
// carrying an *fs.PathError are rewritten to refer to name.
func (s *FS) ReadFile(name string) ([]byte, error) {
	target := name
	if orig, isHashed := s.toPlainName[name]; isHashed {
		target = orig
	}
	data, err := s.fs.ReadFile(target)
	if err != nil {
		return nil, adaptError(err, name)
	}
	return data, nil
}

// New returns an FS that serves the files of fsys under both their original
// paths and content-hashed names.
//
// New reads and hashes every file in fsys up front, so its cost grows with
// the total size of fsys. It panics if fsys cannot be walked or any file in
// it cannot be read. The panic value is an error wrapping the cause. Because
// it panics, New is best suited to filesystems whose contents are fixed when
// the program starts.
func New(fsys fs.ReadFileFS) *FS {
	// The parameter is named fsys (the io/fs convention) so that it does not
	// shadow the fs package inside this function.
	s := &FS{fs: fsys}
	if err := s.build(); err != nil {
		panic(fmt.Errorf("hashfs: building hash index: %w", err))
	}
	return s
}