// Package nixwire encodes and decodes values in the Nix wire format:
// unsigned integers are 8-byte little-endian words, and byte strings are a
// uint64 length prefix followed by the data, zero-padded to a multiple of 8
// bytes.
package nixwire

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"sort"

	"git.lukegb.com/lukegb/depot/go/nix/nixdrv"
)

const (
	// maxInt is the largest value representable by the platform's int type.
	maxInt = int(^uint(0) >> 1)

	// readChunkSize bounds how much memory ReadBytes allocates ahead of data
	// actually arriving, so a corrupt or hostile length prefix cannot force a
	// huge up-front allocation.
	readChunkSize = 64 << 10

	// maxPreallocElems caps the capacity preallocated from an element count
	// read off the wire; longer lists still work, they just grow by append.
	maxPreallocElems = 1024
)

// Serializer writes wire-format values to the embedded io.Writer.
//
// Every Write* method returns the number of bytes written to the underlying
// writer, including when it returns an error part-way through.
type Serializer struct {
	io.Writer
}

// WritePadding writes the zero bytes needed to pad a field of n bytes up to
// the next multiple of 8. Nothing is written when n%8 is not positive.
func (w Serializer) WritePadding(n int64) (int64, error) {
	rem := n % 8
	if rem <= 0 {
		return 0, nil
	}
	wrote, err := w.Write(make([]byte, 8-rem))
	return int64(wrote), err
}

// WriteUint64 writes n as an 8-byte little-endian word.
func (w Serializer) WriteUint64(n uint64) (int64, error) {
	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], n)
	wrote, err := w.Write(buf[:])
	return int64(wrote), err
}

// WriteBytes writes bs as a uint64 length prefix, the raw bytes, and zero
// padding up to a multiple of 8 bytes.
func (w Serializer) WriteBytes(bs []byte) (int64, error) {
	total, err := w.WriteUint64(uint64(len(bs)))
	if err != nil {
		return total, err
	}
	nData, err := w.Write(bs)
	total += int64(nData)
	if err != nil {
		return total, err
	}
	nPad, err := w.WritePadding(int64(len(bs)))
	return total + nPad, err
}

// WriteString writes s using the same encoding as WriteBytes.
func (w Serializer) WriteString(s string) (int64, error) {
	return w.WriteBytes([]byte(s))
}

// WriteStrings writes a uint64 element count followed by each string in
// order.
func (w Serializer) WriteStrings(ss []string) (int64, error) {
	nLen, err := w.WriteUint64(uint64(len(ss)))
	if err != nil {
		return nLen, err
	}
	var nTotal int64
	for _, s := range ss {
		nStr, err := w.WriteString(s)
		if err != nil {
			return nLen + nTotal, err
		}
		nTotal += nStr
	}
	return nLen + nTotal, nil
}

// WriteStringSet writes the keys of s as a string list (see WriteStrings).
// Keys are written in sorted order so the encoding is deterministic. Only
// the keys are considered: entries mapped to false are still written.
func (w Serializer) WriteStringSet(s map[string]bool) (int64, error) {
	ss := make([]string, 0, len(s))
	for k := range s {
		ss = append(ss, k)
	}
	// Map iteration order is randomised; sort so identical sets always
	// produce identical bytes.
	sort.Strings(ss)
	return w.WriteStrings(ss)
}

// WriteDerivation writes drv in this order:
//   - the output count, then for each output its name, path, hash algorithm
//     and hash (outputs in sorted name order);
//   - the input sources (WriteStringSet);
//   - the platform and builder;
//   - the builder arguments (WriteStrings);
//   - the environment size, then each key/value pair in sorted key order.
func (w Serializer) WriteDerivation(drv *nixdrv.BasicDerivation) (int64, error) {
	var total int64
	// add folds one write's byte count into total and passes its error
	// through, so each step reads as `if err := add(w.WriteX(...)); ...`.
	add := func(n int64, err error) error {
		total += n
		return err
	}

	if err := add(w.WriteUint64(uint64(len(drv.Outputs)))); err != nil {
		return total, err
	}
	outputNames := make([]string, 0, len(drv.Outputs))
	for name := range drv.Outputs {
		outputNames = append(outputNames, name)
	}
	sort.Strings(outputNames)
	for _, name := range outputNames {
		output := drv.Outputs[name]
		for _, field := range []string{name, output.Path, output.HashAlgorithm, output.Hash} {
			if err := add(w.WriteString(field)); err != nil {
				return total, err
			}
		}
	}

	if err := add(w.WriteStringSet(drv.InputSrcs)); err != nil {
		return total, err
	}
	if err := add(w.WriteString(drv.Platform)); err != nil {
		return total, err
	}
	if err := add(w.WriteString(drv.Builder)); err != nil {
		return total, err
	}
	if err := add(w.WriteStrings(drv.Args)); err != nil {
		return total, err
	}

	if err := add(w.WriteUint64(uint64(len(drv.Env)))); err != nil {
		return total, err
	}
	envKeys := make([]string, 0, len(drv.Env))
	for k := range drv.Env {
		envKeys = append(envKeys, k)
	}
	sort.Strings(envKeys)
	for _, k := range envKeys {
		if err := add(w.WriteString(k)); err != nil {
			return total, err
		}
		if err := add(w.WriteString(drv.Env[k])); err != nil {
			return total, err
		}
	}
	return total, nil
}

// Deserializer reads wire-format values from the embedded io.Reader. On
// error the underlying reader may have been partially consumed.
type Deserializer struct {
	io.Reader
}

// ReadPadding consumes the padding that follows a field of n bytes (up to
// the next multiple of 8) and fails if any padding byte is non-zero.
func (r Deserializer) ReadPadding(n uint64) error {
	rem := n % 8
	if rem == 0 {
		return nil
	}
	var buf [8]byte
	pad := buf[:8-rem]
	if _, err := io.ReadFull(r, pad); err != nil {
		return fmt.Errorf("reading padding: %w", err)
	}
	for _, b := range pad {
		if b != 0 {
			return errors.New("padding was not all 0")
		}
	}
	return nil
}

// ReadUint64 reads an 8-byte little-endian word.
func (r Deserializer) ReadUint64() (uint64, error) {
	var buf [8]byte
	if _, err := io.ReadFull(r, buf[:]); err != nil {
		return 0, fmt.Errorf("reading uint64: %w", err)
	}
	return binary.LittleEndian.Uint64(buf[:]), nil
}

// ReadBytes reads a length-prefixed, zero-padded byte string as written by
// Serializer.WriteBytes. The length comes from the stream and is untrusted:
// lengths that do not fit in an int are rejected, and the buffer grows as
// data arrives instead of being allocated up front.
func (r Deserializer) ReadBytes() ([]byte, error) {
	strLen, err := r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading length: %w", err)
	}
	if strLen > uint64(maxInt) {
		return nil, fmt.Errorf("string length %d too large", strLen)
	}
	strBuf, err := readExactly(r, int(strLen))
	if err != nil {
		return nil, fmt.Errorf("reading string buffer: %w", err)
	}
	if err := r.ReadPadding(strLen); err != nil {
		return nil, fmt.Errorf("reading string padding: %w", err)
	}
	return strBuf, nil
}

// readExactly reads exactly n bytes from r, growing its buffer by at most
// readChunkSize per step. Like io.ReadFull it returns io.EOF only when no
// bytes were read, and io.ErrUnexpectedEOF when the stream ends part-way.
func readExactly(r io.Reader, n int) ([]byte, error) {
	if n <= readChunkSize {
		buf := make([]byte, n)
		if _, err := io.ReadFull(r, buf); err != nil {
			return nil, err
		}
		return buf, nil
	}
	buf := make([]byte, 0, readChunkSize)
	for len(buf) < n {
		step := n - len(buf)
		if step > readChunkSize {
			step = readChunkSize
		}
		start := len(buf)
		buf = append(buf, make([]byte, step)...)
		if _, err := io.ReadFull(r, buf[start:]); err != nil {
			// An EOF after earlier chunks succeeded is a truncated value,
			// matching what a single io.ReadFull would have reported.
			if errors.Is(err, io.EOF) && start > 0 {
				err = io.ErrUnexpectedEOF
			}
			return nil, err
		}
	}
	return buf, nil
}

// ReadString reads a byte string (see ReadBytes) and returns it as a string.
func (r Deserializer) ReadString() (string, error) {
	bs, err := r.ReadBytes()
	if err != nil {
		return "", err
	}
	return string(bs), nil
}

// ReadStrings reads a uint64 element count followed by that many strings.
func (r Deserializer) ReadStrings() ([]string, error) {
	setLen, err := r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading string set length: %w", err)
	}
	// The count is untrusted; only use it as a bounded capacity hint.
	capHint := setLen
	if capHint > maxPreallocElems {
		capHint = maxPreallocElems
	}
	xs := make([]string, 0, int(capHint))
	for n := uint64(0); n < setLen; n++ {
		s, err := r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading string %d of %d: %w", n, setLen, err)
		}
		xs = append(xs, s)
	}
	return xs, nil
}

// ReadStringSet reads a string list (see ReadStrings) into a set, mapping
// every element to true. Duplicate elements collapse into one entry.
func (r Deserializer) ReadStringSet() (map[string]bool, error) {
	strs, err := r.ReadStrings()
	if err != nil {
		return nil, err
	}
	m := make(map[string]bool, len(strs))
	for _, k := range strs {
		m[k] = true
	}
	return m, nil
}

// ReadDerivation reads a derivation in the field order written by
// Serializer.WriteDerivation. If an output name or environment key repeats,
// the later value wins.
func (r Deserializer) ReadDerivation() (*nixdrv.BasicDerivation, error) {
	drv := &nixdrv.BasicDerivation{}

	outputsLen, err := r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading outputs length: %w", err)
	}
	drv.Outputs = make(map[string]nixdrv.Output)
	for n := uint64(0); n < outputsLen; n++ {
		outputName, err := r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading output name: %w", err)
		}
		var o nixdrv.Output
		if o.Path, err = r.ReadString(); err != nil {
			return nil, fmt.Errorf("reading output path: %w", err)
		}
		if o.HashAlgorithm, err = r.ReadString(); err != nil {
			return nil, fmt.Errorf("reading output hash algorithm: %w", err)
		}
		if o.Hash, err = r.ReadString(); err != nil {
			return nil, fmt.Errorf("reading output hash: %w", err)
		}
		drv.Outputs[outputName] = o
	}

	if drv.InputSrcs, err = r.ReadStringSet(); err != nil {
		return nil, fmt.Errorf("reading input srcs: %w", err)
	}
	if drv.Platform, err = r.ReadString(); err != nil {
		return nil, fmt.Errorf("reading platform: %w", err)
	}
	if drv.Builder, err = r.ReadString(); err != nil {
		return nil, fmt.Errorf("reading builder: %w", err)
	}
	if drv.Args, err = r.ReadStrings(); err != nil {
		return nil, fmt.Errorf("reading args: %w", err)
	}

	envLen, err := r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading environment length: %w", err)
	}
	drv.Env = map[string]string{}
	for n := uint64(0); n < envLen; n++ {
		key, err := r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading environment variable name: %w", err)
		}
		value, err := r.ReadString()
		if err != nil {
			return nil, fmt.Errorf("reading environment variable value: %w", err)
		}
		drv.Env[key] = value
	}
	return drv, nil
}