depot/go/nix/nixstore/remotestore.go

292 lines
7.9 KiB
Go

package nixstore
import (
"encoding/base64"
"fmt"
"log"
"net"
"path"
"strings"
"sync"
"hg.lukegb.com/lukegb/depot/go/nix/nar/narinfo"
"hg.lukegb.com/lukegb/depot/go/nix/nixwire"
)
// DaemonSock is the path of the nix-daemon Unix domain socket; it can be
// passed to OpenDaemon.
const DaemonSock = "/nix/var/nix/daemon-socket/socket"

// Handshake values for the Nix worker protocol.
const (
	// WorkerMagic1 is written by the client to start the handshake
	// (ASCII "nixc").
	WorkerMagic1 = 0x6e697863
	// WorkerMagic2 is the value the daemon must answer with (ASCII "dxio").
	WorkerMagic2 = 0x6478696f
	// ProtocolVersion is the version this client announces, packed as
	// major<<8 | minor, i.e. protocol 1.21 (0x115).
	ProtocolVersion = 1<<8 | 21
)

// Worker operation codes sent at the start of a request.
const (
	WopQuerySubstitutablePathInfo = 21
	WopQueryPathInfo              = 26
)
// Daemon is a client connection to a Nix daemon that speaks the worker
// protocol over conn. Values are created by OpenDaemon.
type Daemon struct {
	// conn is the underlying connection (a Unix socket when created by
	// OpenDaemon).
	conn net.Conn
	// w and r encode requests to, and decode responses from, conn.
	w *nixwire.Serializer
	r *nixwire.Deserializer
	// mu serializes use of the connection: NARInfo and Close hold it for
	// their whole duration, because requests and responses share a single
	// stream.
	mu sync.Mutex
	// NOTE(review): err is not read or written by any code visible here;
	// presumably intended as a sticky error for a desynchronised
	// connection — confirm or remove.
	err error
}
// NARInfo queries the daemon for metadata about storePath using the
// WopQueryPathInfo worker operation and returns it as a narinfo.NarInfo.
//
// StorePath is set to storePath as given. Deriver and References are
// reduced to their base names (an absent deriver becomes ""). NarHash is
// parsed from the daemon's hash string with a "sha256:" prefix. The
// registration time, "ultimate" flag and content address are read and
// discarded. Each signature of the form "<key>:<base64>" is decoded into
// Sig[key].
//
// An error is returned if the daemon reports the path as not valid, if any
// read or write on the connection fails, or if a signature is malformed.
// d.mu is held for the whole request/response exchange.
func (d *Daemon) NARInfo(storePath string) (*narinfo.NarInfo, error) {
	d.mu.Lock()
	defer d.mu.Unlock()

	if _, err := d.w.WriteUint64(WopQueryPathInfo); err != nil {
		return nil, fmt.Errorf("writing worker op WopQueryPathInfo: %w", err)
	}
	if _, err := d.w.WriteString(storePath); err != nil {
		return nil, fmt.Errorf("writing store path query %v: %w", storePath, err)
	}
	if err := d.processStderr(); err != nil {
		return nil, fmt.Errorf("reading stderr from WopQueryPathInfo: %w", err)
	}

	validInt, err := d.r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading path validity: %w", err)
	}
	if validInt != 1 {
		return nil, fmt.Errorf("path %q is not valid", storePath)
	}

	ni := &narinfo.NarInfo{
		StorePath: storePath,
	}

	if ni.Deriver, err = d.r.ReadString(); err != nil {
		return nil, fmt.Errorf("reading deriver: %w", err)
	}
	// path.Base("") is ".", so an absent deriver has to be mapped back to "".
	ni.Deriver = path.Base(ni.Deriver)
	if ni.Deriver == "." {
		ni.Deriver = ""
	}

	hashStr, err := d.r.ReadString()
	if err != nil {
		return nil, fmt.Errorf("reading NAR hash: %w", err)
	}
	if ni.NarHash, err = narinfo.HashFromString("sha256:" + hashStr); err != nil {
		return nil, fmt.Errorf("parsing NAR hash %q: %w", hashStr, err)
	}

	refs, err := d.r.ReadStrings()
	if err != nil {
		return nil, fmt.Errorf("reading references: %w", err)
	}
	for n, ref := range refs {
		refs[n] = path.Base(ref)
	}
	ni.References = refs

	if _, err := d.r.ReadUint64(); err != nil {
		return nil, fmt.Errorf("reading registration time: %w", err)
	}
	if ni.NarSize, err = d.r.ReadUint64(); err != nil {
		return nil, fmt.Errorf("reading narsize: %w", err)
	}
	if _, err := d.r.ReadUint64(); err != nil {
		return nil, fmt.Errorf("reading ultimate: %w", err)
	}
	sigs, err := d.r.ReadStrings()
	if err != nil {
		return nil, fmt.Errorf("reading sigs: %w", err)
	}
	// The content address is sent as a string (empty when the path is not
	// content-addressed). Reading it as a uint64 consumed only the length
	// prefix and left the string bytes unread for content-addressed paths,
	// desynchronising every later exchange on this connection.
	if _, err := d.r.ReadString(); err != nil {
		return nil, fmt.Errorf("reading CA: %w", err)
	}

	ni.Sig = make(map[string][]byte, len(sigs))
	for _, sigStr := range sigs {
		// Signatures look like "<key name>:<base64 signature>". Guard the
		// separator instead of indexing a Split result, which panicked on
		// a signature without a colon.
		keyName, sigB64, ok := strings.Cut(sigStr, ":")
		if !ok {
			return nil, fmt.Errorf("malformed signature %q: missing key name separator", sigStr)
		}
		sigHash, err := base64.StdEncoding.DecodeString(sigB64)
		if err != nil {
			return nil, fmt.Errorf("decoding signature %q: %w", sigStr, err)
		}
		ni.Sig[keyName] = sigHash
	}
	return ni, nil
}
// Close closes the connection to the daemon and returns any error from
// closing it. Close takes d.mu, so it waits for an in-flight NARInfo call to
// finish before closing.
func (d *Daemon) Close() error {
	d.mu.Lock()
	defer d.mu.Unlock()

	if err := d.conn.Close(); err != nil {
		return err
	}
	return nil
}
// readFields reads a count-prefixed list of typed fields, as carried by the
// STDERR_START_ACTIVITY and STDERR_RESULT messages.
//
// Each field is a uint64 type tag followed by its value: tag 0 (tInt) is a
// uint64 and tag 1 (tString) is a string. Values are returned in wire order
// as uint64 or string. Any other tag is an error, because the size of its
// payload is unknown.
func (d *Daemon) readFields() ([]any, error) {
	fieldsSz, err := d.r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading fields count: %w", err)
	}
	var fields []any
	// Count in uint64 rather than converting fieldsSz to int. That conversion
	// truncates on 32-bit platforms and goes negative for counts >= 2^63,
	// which would silently skip fields and desynchronise the stream.
	for n := uint64(0); n < fieldsSz; n++ {
		fieldTyp, err := d.r.ReadUint64()
		if err != nil {
			return nil, fmt.Errorf("reading field %d type: %w", n, err)
		}
		switch fieldTyp {
		case 0: // tInt
			fieldVal, err := d.r.ReadUint64()
			if err != nil {
				return nil, fmt.Errorf("reading field %d int: %w", n, err)
			}
			fields = append(fields, fieldVal)
		case 1: // tString
			fieldVal, err := d.r.ReadString()
			if err != nil {
				return nil, fmt.Errorf("reading field %d str: %w", n, err)
			}
			fields = append(fields, fieldVal)
		default:
			return nil, fmt.Errorf("reading field %d: unsupported field type %x", n, fieldTyp)
		}
	}
	return fields, nil
}
// processStderr consumes the stream of stderr messages that the daemon sends
// before the result of an operation. It returns nil once STDERR_LAST arrives.
//
// Handling by message type:
//   - Log output (STDERR_WRITE, STDERR_NEXT) and activity messages
//     (STDERR_START_ACTIVITY, STDERR_STOP_ACTIVITY, STDERR_RESULT) are
//     written to the standard logger.
//   - STDERR_ERROR is returned as an error carrying the daemon's message and
//     status code. This is the error layout used for clients announcing
//     protocol minor < 26, as this one does.
//   - STDERR_READ (the daemon asking for input) is unsupported and returned
//     as an error.
//   - An unrecognised message code is an error, since its payload cannot be
//     skipped.
func (d *Daemon) processStderr() error {
	const (
		stderrWrite         = 0x64617416
		stderrRead          = 0x64617461
		stderrError         = 0x63787470
		stderrNext          = 0x6f6c6d67
		stderrStartActivity = 0x53545254
		stderrStopActivity  = 0x53544f50
		stderrResult        = 0x52534c54
		stderrLast          = 0x616c7473
	)
	for {
		msg, err := d.r.ReadUint64()
		if err != nil {
			return fmt.Errorf("read stderr msg code: %w", err)
		}
		switch msg {
		case stderrWrite:
			stderr, err := d.r.ReadString()
			if err != nil {
				return fmt.Errorf("read stderr msg: %w", err)
			}
			log.Println(stderr)
		case stderrRead:
			return fmt.Errorf("STDERR_READ requested")
		case stderrError:
			errStr, err := d.r.ReadString()
			if err != nil {
				return fmt.Errorf("STDERR_ERROR reading error string: %w", err)
			}
			status, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_ERROR reading uint64: %w", err)
			}
			return fmt.Errorf("error code %d: %v", status, errStr)
		case stderrNext:
			// STDERR_NEXT carries a log line, not a failure. Report it and
			// keep reading until STDERR_LAST; returning here aborted any
			// operation during which the daemon logged something.
			line, err := d.r.ReadString()
			if err != nil {
				return fmt.Errorf("STDERR_NEXT reading log line: %w", err)
			}
			log.Print(line)
		case stderrStartActivity:
			activity, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading ActivityId: %w", err)
			}
			lvl, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading level: %w", err)
			}
			actType, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading activity type: %w", err)
			}
			s, err := d.r.ReadString()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading s: %w", err)
			}
			fields, err := d.readFields()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY fields: %w", err)
			}
			parentActivity, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading parent activity: %w", err)
			}
			log.Printf("START_ACTIVITY activity=%d, lvl=%d, actType=%d, s=%s, fields=%v, parentActivity=%d", activity, lvl, actType, s, fields, parentActivity)
		case stderrStopActivity:
			activity, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_STOP_ACTIVITY reading ActivityId: %w", err)
			}
			log.Printf("STOP_ACTIVITY activity=%d", activity)
		case stderrResult:
			activity, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_RESULT reading ActivityId: %w", err)
			}
			resultTyp, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_RESULT reading result: %w", err)
			}
			fields, err := d.readFields()
			if err != nil {
				return fmt.Errorf("STDERR_RESULT fields: %w", err)
			}
			log.Printf("activity RESULT activity=%d resultType=%d fields=%v", activity, resultTyp, fields)
		case stderrLast:
			return nil
		default:
			// Without knowing the payload layout we cannot skip it; carrying
			// on would interpret payload bytes as message codes.
			return fmt.Errorf("unknown stderr msg code %#x", msg)
		}
	}
}
// hello performs the client side of the worker-protocol handshake:
//  1. Exchange the magic numbers.
//  2. Require a daemon protocol version of 1.x with minor >= 21.
//  3. Announce ProtocolVersion, followed by the two obsolete option words
//     (CPU affinity and reserveSpace), both 0.
//  4. Drain the daemon's stderr stream.
//
// hello does not take d.mu; OpenDaemon calls it before the Daemon is
// returned to anyone else.
func (d *Daemon) hello() error {
	if _, err := d.w.WriteUint64(WorkerMagic1); err != nil {
		return fmt.Errorf("writing magic 1: %w", err)
	}

	magic2, err := d.r.ReadUint64()
	switch {
	case err != nil:
		return fmt.Errorf("reading magic 2: %w", err)
	case magic2 != WorkerMagic2:
		return fmt.Errorf("magic 2 mismatch: got %x, wanted %x", magic2, WorkerMagic2)
	}

	daemonVersion, err := d.r.ReadUint64()
	if err != nil {
		return fmt.Errorf("reading daemon version: %w", err)
	}
	// Versions are packed as major<<8 | minor.
	major := int(daemonVersion >> 8 & 0xff)
	minor := int(daemonVersion & 0xff)
	switch {
	case major != 1:
		return fmt.Errorf("daemon major version mismatch: got %d, want 1", major)
	case minor < 21:
		return fmt.Errorf("daemon minor version too old: got %d, want at least 21", minor)
	}

	if _, err := d.w.WriteUint64(ProtocolVersion); err != nil {
		return fmt.Errorf("writing protocol version: %w", err)
	}
	// Both option words are obsolete; the client always sends 0 for them.
	if _, err := d.w.WriteUint64(0); err != nil {
		return fmt.Errorf("writing obsolete CPU affinity: %w", err)
	}
	if _, err := d.w.WriteUint64(0); err != nil {
		return fmt.Errorf("writing obsolete reserveSpace: %w", err)
	}

	// The handshake ends with the daemon's stderr stream.
	return d.processStderr()
}
// OpenDaemon dials the Nix daemon's Unix socket at path (for example
// DaemonSock) and performs the protocol handshake. If the handshake fails,
// the connection is closed and the handshake error is returned.
func OpenDaemon(path string) (*Daemon, error) {
	conn, err := net.Dial("unix", path)
	if err != nil {
		return nil, fmt.Errorf("dialing %v: %w", path, err)
	}

	d := &Daemon{conn: conn}
	d.w = &nixwire.Serializer{conn}
	d.r = &nixwire.Deserializer{conn}

	if err = d.hello(); err != nil {
		// Best effort: the handshake failure is the error worth reporting.
		_ = d.Close()
		return nil, fmt.Errorf("sending hello to %v: %w", path, err)
	}
	return d, nil
}