// Package nixstore talks to a Nix daemon over its Unix-socket worker protocol.
package nixstore

import (
	"encoding/base64"
	"fmt"
	"log"
	"net"
	"path"
	"strings"
	"sync"

	"hg.lukegb.com/lukegb/depot/go/nix/nar/narinfo"
	"hg.lukegb.com/lukegb/depot/go/nix/nixwire"
)

const (
	// DaemonSock is the standard location of the system Nix daemon's socket;
	// it is suitable as the argument to OpenDaemon.
	DaemonSock = "/nix/var/nix/daemon-socket/socket"

	// WorkerMagic1 is sent by the client to open the handshake.
	WorkerMagic1 = 0x6e697863
	// WorkerMagic2 is the reply the daemon must send to WorkerMagic1.
	WorkerMagic2 = 0x6478696f

	// ProtocolVersion is the worker protocol version this client announces,
	// encoded as major<<8 | minor (here 1.21).
	ProtocolVersion = 0x115

	// Worker operation codes.
	WopQueryPathInfo              = 26
	WopQuerySubstitutablePathInfo = 21
)

// Daemon is a connection to a Nix daemon. Each request/response exchange
// holds mu, so a single Daemon may be shared between goroutines; requests
// are serialised.
type Daemon struct {
	conn net.Conn
	w    *nixwire.Serializer
	r    *nixwire.Deserializer

	mu  sync.Mutex
	err error // NOTE(review): not currently read or written in this file
}

// NARInfo asks the daemon for metadata about storePath (a full store path)
// via WopQueryPathInfo and returns it as a NarInfo. The deriver and
// references are reduced to their base names. It returns an error if the
// daemon reports the path as not valid, or on any protocol/transport error.
func (d *Daemon) NARInfo(storePath string) (*narinfo.NarInfo, error) {
	d.mu.Lock()
	defer d.mu.Unlock()

	if _, err := d.w.WriteUint64(WopQueryPathInfo); err != nil {
		return nil, fmt.Errorf("writing worker op WopQueryPathInfo: %w", err)
	}
	if _, err := d.w.WriteString(storePath); err != nil {
		return nil, fmt.Errorf("writing store path query %v: %w", storePath, err)
	}
	if err := d.processStderr(); err != nil {
		return nil, fmt.Errorf("reading stderr from WopQueryPathInfo: %w", err)
	}

	validInt, err := d.r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading path validity: %w", err)
	}
	if validInt != 1 {
		return nil, fmt.Errorf("path %q is not valid", storePath)
	}

	ni := &narinfo.NarInfo{StorePath: storePath}

	deriver, err := d.r.ReadString()
	if err != nil {
		return nil, fmt.Errorf("reading deriver: %w", err)
	}
	// The daemon sends an empty string when no deriver is recorded; keep it
	// empty rather than letting path.Base("") turn it into ".".
	if deriver != "" {
		ni.Deriver = path.Base(deriver)
	}

	hashStr, err := d.r.ReadString()
	if err != nil {
		return nil, fmt.Errorf("reading NAR hash: %w", err)
	}
	if ni.NarHash, err = narinfo.HashFromString("sha256:" + hashStr); err != nil {
		return nil, fmt.Errorf("parsing NAR hash %q: %w", hashStr, err)
	}

	refs, err := d.r.ReadStrings()
	if err != nil {
		return nil, fmt.Errorf("reading references: %w", err)
	}
	for n, ref := range refs {
		refs[n] = path.Base(ref)
	}
	ni.References = refs

	if _, err := d.r.ReadUint64(); err != nil {
		return nil, fmt.Errorf("reading registration time: %w", err)
	}
	if ni.NarSize, err = d.r.ReadUint64(); err != nil {
		return nil, fmt.Errorf("reading narsize: %w", err)
	}
	if _, err := d.r.ReadUint64(); err != nil {
		return nil, fmt.Errorf("reading ultimate: %w", err)
	}
	sigs, err := d.r.ReadStrings()
	if err != nil {
		return nil, fmt.Errorf("reading sigs: %w", err)
	}
	// The content address is sent as a string (empty when there is none).
	// Reading it as an integer only consumes the length word, which leaves
	// the bytes of a non-empty CA in the stream for the next request.
	if _, err := d.r.ReadString(); err != nil {
		return nil, fmt.Errorf("reading CA: %w", err)
	}

	ni.Sig = make(map[string][]byte, len(sigs))
	for _, sigStr := range sigs {
		// Signatures have the form "<key name>:<base64 signature>".
		keyName, sigB64, ok := strings.Cut(sigStr, ":")
		if !ok {
			return nil, fmt.Errorf("malformed signature %q: missing ':' separator", sigStr)
		}
		sig, err := base64.StdEncoding.DecodeString(sigB64)
		if err != nil {
			return nil, fmt.Errorf("decoding signature %q: %w", sigStr, err)
		}
		ni.Sig[keyName] = sig
	}
	return ni, nil
}

// Close closes the connection to the daemon. Because it takes mu, it waits
// for any in-flight request on d to finish first.
func (d *Daemon) Close() error {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.conn.Close()
}

// readFields reads a count-prefixed list of typed fields, as attached to
// activity and result messages. Each field is a type tag (0 = integer,
// 1 = string) followed by its value; integers are returned as uint64 and
// strings as string.
func (d *Daemon) readFields() ([]any, error) {
	fieldsSz, err := d.r.ReadUint64()
	if err != nil {
		return nil, fmt.Errorf("reading fields count: %w", err)
	}
	var fields []any
	for n := 0; n < int(fieldsSz); n++ {
		fieldTyp, err := d.r.ReadUint64()
		if err != nil {
			return nil, fmt.Errorf("reading field %d type: %w", n, err)
		}
		switch fieldTyp {
		case 0: // tInt
			fieldVal, err := d.r.ReadUint64()
			if err != nil {
				return nil, fmt.Errorf("reading field %d int: %w", n, err)
			}
			fields = append(fields, fieldVal)
		case 1: // tString
			fieldVal, err := d.r.ReadString()
			if err != nil {
				return nil, fmt.Errorf("reading field %d str: %w", n, err)
			}
			fields = append(fields, fieldVal)
		default:
			return nil, fmt.Errorf("reading field %d: unsupported field type %x", n, fieldTyp)
		}
	}
	return fields, nil
}

// processStderr consumes the daemon's stderr message stream until
// STDERR_LAST, logging output and activity messages as they arrive. It
// returns an error if the daemon reports a failure (STDERR_ERROR), asks for
// client input (STDERR_READ, unsupported here), sends an unknown message
// code, or the read fails.
func (d *Daemon) processStderr() error {
	for {
		msg, err := d.r.ReadUint64()
		if err != nil {
			return fmt.Errorf("read stderr msg code: %w", err)
		}
		switch msg {
		case 0x64617416: // STDERR_WRITE
			stderr, err := d.r.ReadString()
			if err != nil {
				return fmt.Errorf("read stderr msg: %w", err)
			}
			log.Println(stderr)
		case 0x64617461: // STDERR_READ
			return fmt.Errorf("STDERR_READ requested")
		case 0x63787470: // STDERR_ERROR
			// Protocol minors below 26 (we announce 21) send the error as a
			// message string followed by a status code.
			errStr, err := d.r.ReadString()
			if err != nil {
				return fmt.Errorf("STDERR_ERROR reading error string: %w", err)
			}
			status, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_ERROR reading uint64: %w", err)
			}
			return fmt.Errorf("error code %d: %v", status, errStr)
		case 0x6f6c6d67: // STDERR_NEXT
			// A log line from the daemon, not a failure: the operation's
			// response still follows, so log it and keep reading.
			logLine, err := d.r.ReadString()
			if err != nil {
				return fmt.Errorf("STDERR_NEXT reading log message: %w", err)
			}
			log.Print(logLine)
		case 0x53545254: // STDERR_START_ACTIVITY
			activity, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading ActivityId: %w", err)
			}
			lvl, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading level: %w", err)
			}
			actType, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading activity type: %w", err)
			}
			s, err := d.r.ReadString()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading s: %w", err)
			}
			fields, err := d.readFields()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY fields: %w", err)
			}
			parentActivity, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_START_ACTIVITY reading parent activity: %w", err)
			}
			log.Printf("START_ACTIVITY activity=%d, lvl=%d, actType=%d, s=%s, fields=%v, parentActivity=%d", activity, lvl, actType, s, fields, parentActivity)
		case 0x53544f50: // STDERR_STOP_ACTIVITY
			activity, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_STOP_ACTIVITY reading ActivityId: %w", err)
			}
			log.Printf("STOP_ACTIVITY activity=%d", activity)
		case 0x52534c54: // STDERR_RESULT
			activity, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_RESULT reading ActivityId: %w", err)
			}
			resultTyp, err := d.r.ReadUint64()
			if err != nil {
				return fmt.Errorf("STDERR_RESULT reading result: %w", err)
			}
			fields, err := d.readFields()
			if err != nil {
				return fmt.Errorf("STDERR_RESULT fields: %w", err)
			}
			log.Printf("activity RESULT activity=%d resultType=%d fields=%v", activity, resultTyp, fields)
		case 0x616c7473: // STDERR_LAST
			return nil
		default:
			// Skipping an unknown message would misread its payload as the
			// next message code and desync the connection.
			return fmt.Errorf("unknown stderr msg code %#x", msg)
		}
	}
}

// hello performs the worker-protocol handshake. It exchanges magic numbers
// and checks that the daemon speaks protocol 1.x with a minor version of at
// least 21. It then announces ProtocolVersion, sends zero for the obsolete
// CPU-affinity and reserveSpace options, and drains the daemon's initial
// stderr stream.
func (d *Daemon) hello() error {
	if _, err := d.w.WriteUint64(WorkerMagic1); err != nil {
		return fmt.Errorf("writing magic 1: %w", err)
	}
	magic2, err := d.r.ReadUint64()
	if err != nil {
		return fmt.Errorf("reading magic 2: %w", err)
	}
	if magic2 != WorkerMagic2 {
		return fmt.Errorf("magic 2 mismatch: got %x, wanted %x", magic2, WorkerMagic2)
	}

	daemonVersion, err := d.r.ReadUint64()
	if err != nil {
		return fmt.Errorf("reading daemon version: %w", err)
	}
	majorVersion := int((daemonVersion >> 8) & 0xff)
	minorVersion := int(daemonVersion & 0xff)
	if majorVersion != 1 {
		return fmt.Errorf("daemon major version mismatch: got %d, want 1", majorVersion)
	}
	if minorVersion < 21 {
		return fmt.Errorf("daemon minor version too old: got %d, want at least 21", minorVersion)
	}

	if _, err := d.w.WriteUint64(ProtocolVersion); err != nil {
		return fmt.Errorf("writing protocol version: %w", err)
	}
	if _, err := d.w.WriteUint64(0); err != nil {
		return fmt.Errorf("writing obsolete CPU affinity: %w", err)
	}
	if _, err := d.w.WriteUint64(0); err != nil {
		return fmt.Errorf("writing obsolete reserveSpace: %w", err)
	}
	return d.processStderr()
}

// OpenDaemon connects to the Nix daemon listening on the Unix socket at
// sockPath (for the system daemon, DaemonSock) and performs the protocol
// handshake. If the handshake fails, the connection is closed before the
// error is returned.
func OpenDaemon(sockPath string) (*Daemon, error) {
	conn, err := net.Dial("unix", sockPath)
	if err != nil {
		return nil, fmt.Errorf("dialing %v: %w", sockPath, err)
	}
	d := &Daemon{conn: conn, w: &nixwire.Serializer{conn}, r: &nixwire.Deserializer{conn}}
	if err := d.hello(); err != nil {
		d.Close() // best effort; the handshake error is the one worth reporting
		return nil, fmt.Errorf("sending hello to %v: %w", sockPath, err)
	}
	return d, nil
}