depot/go/nix/nixwire/nixwire.go

285 lines
6.5 KiB
Go
Raw Permalink Normal View History

package nixwire
import (
	"encoding/binary"
	"fmt"
	"io"
	"sort"

	"git.lukegb.com/lukegb/depot/go/nix/nixdrv"
)
// Serializer writes wire-format values to the embedded io.Writer. Its methods
// emit integers as 8-byte little-endian values and byte strings as a length
// prefix followed by the data, zero-padded to a multiple of 8 bytes.
type Serializer struct {
io.Writer
}
// WritePadding writes the zero bytes needed to round a payload of n bytes up
// to the next multiple of 8 and returns the number of bytes written. Nothing
// is written when n%8 is not positive, i.e. when n is already aligned or is
// negative.
func (w Serializer) WritePadding(n int64) (int64, error) {
	rem := n % 8
	if rem <= 0 {
		return 0, nil
	}
	written, err := w.Write(make([]byte, 8-rem))
	return int64(written), err
}
// WriteUint64 encodes n as 8 little-endian bytes, writes them to the
// underlying writer, and returns the number of bytes written.
func (w Serializer) WriteUint64(n uint64) (int64, error) {
	var encoded [8]byte
	binary.LittleEndian.PutUint64(encoded[:], n)
	count, err := w.Write(encoded[:])
	return int64(count), err
}
// WriteBytes writes bs as a length-prefixed byte string: the length via
// WriteUint64, then the raw bytes, then the padding produced by WritePadding
// for that length. It returns the total number of bytes written, including
// any bytes written by a step that failed.
func (w Serializer) WriteBytes(bs []byte) (int64, error) {
	total, err := w.WriteUint64(uint64(len(bs)))
	if err != nil {
		return total, err
	}

	wrote, err := w.Write(bs)
	total += int64(wrote)
	if err != nil {
		return total, err
	}

	padded, err := w.WritePadding(int64(len(bs)))
	return total + padded, err
}
2022-10-09 22:48:18 +00:00
// WriteString writes s with the same framing as WriteBytes, by converting it
// to a byte slice and delegating. It returns whatever WriteBytes returns.
func (w Serializer) WriteString(s string) (int64, error) {
	payload := []byte(s)
	return w.WriteBytes(payload)
}
// WriteStrings writes a list of strings: the element count via WriteUint64,
// followed by each element via WriteString, in slice order.
//
// It returns the total number of bytes written. On failure the count includes
// the bytes written by the call that failed, so it always reflects what
// reached the writer.
func (w Serializer) WriteStrings(ss []string) (int64, error) {
	total, err := w.WriteUint64(uint64(len(ss)))
	if err != nil {
		return total, err
	}
	for _, s := range ss {
		n, err := w.WriteString(s)
		// Add before checking err: a failed WriteString may still have
		// written part of the element, and those bytes must be counted.
		total += n
		if err != nil {
			return total, err
		}
	}
	return total, nil
}
// WriteStringSet writes the keys of s as a list of strings via WriteStrings.
// Every key is written regardless of the bool stored under it.
//
// The keys are sorted first: Go randomizes map iteration order, so without
// sorting the same set would serialize to different byte sequences from one
// call to the next.
func (w Serializer) WriteStringSet(s map[string]bool) (int64, error) {
	keys := make([]string, 0, len(s))
	for k := range s {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return w.WriteStrings(keys)
}
// WriteDerivation serializes drv and returns the total number of bytes
// written, including bytes written by a step that failed.
//
// Fields are written in this order:
//  1. the number of outputs, then for each output its name, Path,
//     HashAlgorithm and Hash as strings;
//  2. InputSrcs via WriteStringSet;
//  3. Platform and Builder as strings;
//  4. Args via WriteStrings;
//  5. the number of environment entries, then each key and value as strings.
//
// Outputs and Env are maps; their entries are written in sorted key order so
// that the bytes produced for these sections do not depend on Go's randomized
// map iteration order.
func (w Serializer) WriteDerivation(drv *nixdrv.BasicDerivation) (int64, error) {
	total, err := w.WriteUint64(uint64(len(drv.Outputs)))
	if err != nil {
		return total, err
	}

	// add folds the byte count of one write into the running total and hands
	// the write's error back, so each step is a single checked call.
	add := func(n int64, err error) error {
		total += n
		return err
	}

	outputNames := make([]string, 0, len(drv.Outputs))
	for name := range drv.Outputs {
		outputNames = append(outputNames, name)
	}
	sort.Strings(outputNames)
	for _, name := range outputNames {
		output := drv.Outputs[name]
		for _, field := range [...]string{name, output.Path, output.HashAlgorithm, output.Hash} {
			if err := add(w.WriteString(field)); err != nil {
				return total, err
			}
		}
	}

	if err := add(w.WriteStringSet(drv.InputSrcs)); err != nil {
		return total, err
	}
	if err := add(w.WriteString(drv.Platform)); err != nil {
		return total, err
	}
	if err := add(w.WriteString(drv.Builder)); err != nil {
		return total, err
	}
	if err := add(w.WriteStrings(drv.Args)); err != nil {
		return total, err
	}

	if err := add(w.WriteUint64(uint64(len(drv.Env)))); err != nil {
		return total, err
	}
	envKeys := make([]string, 0, len(drv.Env))
	for k := range drv.Env {
		envKeys = append(envKeys, k)
	}
	sort.Strings(envKeys)
	for _, k := range envKeys {
		if err := add(w.WriteString(k)); err != nil {
			return total, err
		}
		if err := add(w.WriteString(drv.Env[k])); err != nil {
			return total, err
		}
	}
	return total, nil
}
2022-10-09 22:48:18 +00:00
// Deserializer reads wire-format values from the embedded io.Reader. Its
// methods decode 8-byte little-endian integers and length-prefixed byte
// strings whose data is followed by zero padding up to a multiple of 8 bytes.
type Deserializer struct {
io.Reader
}
// ReadPadding consumes the padding that follows a payload of n bytes, i.e.
// the bytes needed to reach the next multiple of 8, and checks that every
// padding byte is zero. It reads nothing when n is already a multiple of 8.
//
// It returns an error if the padding cannot be read in full (wrapping the
// underlying read error) or if any padding byte is non-zero.
func (r Deserializer) ReadPadding(n uint64) error {
	rem := n % 8
	if rem == 0 {
		return nil
	}
	// Take the remainder in uint64 and only then narrow it. Converting n to
	// int first wraps for values above the int range, which can produce a
	// negative remainder and therefore a padding length larger than 7.
	buf := make([]byte, 8-int(rem))
	if _, err := io.ReadFull(r, buf); err != nil {
		return fmt.Errorf("reading padding: %w", err)
	}
	for _, b := range buf {
		if b != 0 {
			return fmt.Errorf("padding was not all 0")
		}
	}
	return nil
}
// ReadUint64 reads exactly 8 bytes from the underlying reader and decodes
// them as a little-endian uint64. A failed or short read is returned wrapped.
func (r Deserializer) ReadUint64() (uint64, error) {
	var raw [8]byte
	_, err := io.ReadFull(r, raw[:])
	if err != nil {
		return 0, fmt.Errorf("reading uint64: %w", err)
	}
	return binary.LittleEndian.Uint64(raw[:]), nil
}
func (r Deserializer) ReadBytes() ([]byte, error) {
2022-10-09 22:48:18 +00:00
strLen, err := r.ReadUint64()
if err != nil {
return nil, fmt.Errorf("reading length: %w", err)
2022-10-09 22:48:18 +00:00
}
strBuf := make([]byte, int(strLen))
if _, err := io.ReadFull(r, strBuf); err != nil {
return nil, fmt.Errorf("reading string buffer: %w", err)
2022-10-09 22:48:18 +00:00
}
if err := r.ReadPadding(strLen); err != nil {
return nil, fmt.Errorf("reading string padding: %w", err)
2022-10-09 22:48:18 +00:00
}
return strBuf, nil
}
func (r Deserializer) ReadString() (string, error) {
bs, err := r.ReadBytes()
if err != nil {
return "", err
}
return string(bs), nil
2022-10-09 22:48:18 +00:00
}
// ReadStrings reads a list of strings: a uint64 element count followed by
// that many strings, each read with ReadString. The result is a non-nil
// slice, empty when the count is zero.
//
// The count comes from the stream and is not trusted. Only a bounded amount
// of capacity is reserved up front and the slice grows as elements are
// actually read, so a bogus count yields a read error instead of a panic or
// an enormous allocation.
func (r Deserializer) ReadStrings() ([]string, error) {
	setLen, err := r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading string set length: %w", err)
	}

	const maxPrealloc = 1024
	prealloc := setLen
	if prealloc > maxPrealloc {
		prealloc = maxPrealloc
	}
	xs := make([]string, 0, int(prealloc))
	for n := uint64(0); n < setLen; n++ {
		s, err := r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading string %d of %d: %w", n, setLen, err)
		}
		xs = append(xs, s)
	}
	return xs, nil
}
// ReadStringSet reads a list of strings via ReadStrings and returns it as a
// set: a map in which every string read is a key mapped to true. Duplicate
// strings collapse into a single key. Errors from ReadStrings are returned
// unchanged with a nil map.
func (r Deserializer) ReadStringSet() (map[string]bool, error) {
	keys, err := r.ReadStrings()
	if err != nil {
		return nil, err
	}

	set := make(map[string]bool, len(keys))
	for i := range keys {
		set[keys[i]] = true
	}
	return set, nil
}
// ReadDerivation decodes a BasicDerivation from the stream. It reads, in
// order: the output count and, per output, its name, path, hash algorithm and
// hash; the input sources as a string set; the platform and builder strings;
// the argument list; and the environment as a count followed by key/value
// string pairs.
//
// On any failure it returns a nil derivation and an error that names the
// field being read and wraps the underlying error.
func (r Deserializer) ReadDerivation() (*nixdrv.BasicDerivation, error) {
	drv := &nixdrv.BasicDerivation{}

	numOutputs, err := r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading outputs length: %w", err)
	}
	drv.Outputs = make(map[string]nixdrv.Output)
	for i := uint64(0); i < numOutputs; i++ {
		name, err := r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading output name: %w", err)
		}

		var out nixdrv.Output
		out.Path, err = r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading output path: %w", err)
		}
		out.HashAlgorithm, err = r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading output hash algorithm: %w", err)
		}
		out.Hash, err = r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading output hash: %w", err)
		}
		drv.Outputs[name] = out
	}

	drv.InputSrcs, err = r.ReadStringSet()
	if err != nil {
		return nil, fmt.Errorf("reading input srcs: %w", err)
	}
	drv.Platform, err = r.ReadString()
	if err != nil {
		return nil, fmt.Errorf("reading platform: %w", err)
	}
	drv.Builder, err = r.ReadString()
	if err != nil {
		return nil, fmt.Errorf("reading builder: %w", err)
	}
	drv.Args, err = r.ReadStrings()
	if err != nil {
		return nil, fmt.Errorf("reading args: %w", err)
	}

	numEnv, err := r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading environment length: %w", err)
	}
	drv.Env = make(map[string]string)
	for i := uint64(0); i < numEnv; i++ {
		name, err := r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading environment variable name: %w", err)
		}
		val, err := r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading environment variable value: %w", err)
		}
		drv.Env[name] = val
	}

	return drv, nil
}