From d6118944faa7a23ba63d0d3135fecca4d512d30a Mon Sep 17 00:00:00 2001 From: Luke Granger-Brown Date: Sun, 9 Oct 2022 23:48:18 +0100 Subject: [PATCH] nixstore: add remotestore --- go/nix/nixstore/default.nix | 2 +- go/nix/nixstore/nixstore_test.go | 18 -- go/nix/nixstore/remotestore.go | 279 ++++++++++++++++++ go/nix/nixstore/remotestore_test.go | 17 ++ .../nixstore/{nixstore.go => sqlitestore.go} | 0 go/nix/nixstore/sqlitestore_test.go | 1 + go/nix/nixwire/nixwire.go | 65 ++++ go/nix/nixwire/nixwire_test.go | 31 ++ 8 files changed, 394 insertions(+), 19 deletions(-) delete mode 100644 go/nix/nixstore/nixstore_test.go create mode 100644 go/nix/nixstore/remotestore.go create mode 100644 go/nix/nixstore/remotestore_test.go rename go/nix/nixstore/{nixstore.go => sqlitestore.go} (100%) create mode 100644 go/nix/nixstore/sqlitestore_test.go create mode 100644 go/nix/nixwire/nixwire_test.go diff --git a/go/nix/nixstore/default.nix b/go/nix/nixstore/default.nix index d69a3a63fe..241d82c63b 100644 --- a/go/nix/nixstore/default.nix +++ b/go/nix/nixstore/default.nix @@ -7,7 +7,7 @@ depot.third_party.buildGo.package { name = "nixstore"; path = "hg.lukegb.com/lukegb/depot/go/nix/nixstore"; srcs = [ - ./nixstore.go + ./sqlitestore.go ]; deps = with depot; [ go.nix.nar.narinfo diff --git a/go/nix/nixstore/nixstore_test.go b/go/nix/nixstore/nixstore_test.go deleted file mode 100644 index 429ddf2730..0000000000 --- a/go/nix/nixstore/nixstore_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package nixstore - -import "testing" - -func TestNARInfo(t *testing.T) { - db, err := Open(DefaultStoreDB) - if err != nil { - t.Fatalf("Open: %v", err) - } - - ni, err := db.NARInfo("/nix/store/yk8ps7v1jhwpj82pigmqjb68ln7bgjbn-acl-2.3.1") - if err != nil { - t.Fatalf("NARInfo: %v", err) - } - t.Logf("%#v", ni) - t.Log(ni.String()) - t.Error("meep") -} diff --git a/go/nix/nixstore/remotestore.go b/go/nix/nixstore/remotestore.go new file mode 100644 index 0000000000..85bed91f86 --- /dev/null +++ b/go/nix/nixstore/remotestore.go @@ -0,0 +1,279 @@ +package nixstore + +import ( + "encoding/base64" + "fmt" + "log" + "net" + "path" + "strings" + + "hg.lukegb.com/lukegb/depot/go/nix/nar/narinfo" + "hg.lukegb.com/lukegb/depot/go/nix/nixwire" +) + +const ( + DaemonSock = "/nix/var/nix/daemon-socket/socket" + WorkerMagic1 = 0x6e697863 + WorkerMagic2 = 0x6478696f + ProtocolVersion = 0x115 + + WopQueryPathInfo = 26 + WopQuerySubstitutablePathInfo = 21 +) + +type Daemon struct { + conn net.Conn + w *nixwire.Serializer + r *nixwire.Deserializer +} + +func (d *Daemon) NARInfo(storePath string) (*narinfo.NarInfo, error) { + if _, err := d.w.WriteUint64(WopQueryPathInfo); err != nil { + return nil, fmt.Errorf("writing worker op WopQueryPathInfo: %w", err) + } + if _, err := d.w.WriteString(storePath); err != nil { + return nil, fmt.Errorf("writing store path query %v: %w", storePath, err) + } + + if err := d.processStderr(); err != nil { + return nil, fmt.Errorf("reading stderr from WopQueryPathInfo: %w", err) + } + + validInt, err := d.r.ReadUint64() + if err != nil { + return nil, fmt.Errorf("reading path validity: %w", err) + } + valid := validInt == uint64(1) + if !valid { + return nil, fmt.Errorf("path %q is not valid", storePath) + } + + ni := &narinfo.NarInfo{ + StorePath: storePath, + } + if ni.Deriver, err = d.r.ReadString(); err != nil { + return nil, fmt.Errorf("reading deriver: %w", err) + } + + hashStr, err := d.r.ReadString() + if err != nil { + return nil, fmt.Errorf("reading NAR hash: %w", err) + } + if ni.NarHash, err = narinfo.HashFromString("sha256:" + hashStr); err != nil { + return nil, fmt.Errorf("parsing NAR hash %q: %w", hashStr, err) + } + + refs, err := d.r.ReadStrings() + if err != nil { + return nil, fmt.Errorf("reading referrers: %w", err) + } + for n, ref := range refs { + refs[n] = path.Base(ref) + } + ni.References = refs + + if _, err := d.r.ReadUint64(); err != nil { + return nil, fmt.Errorf("reading registration time: %w", err) + } + if ni.NarSize, err = d.r.ReadUint64(); err != nil { + return nil, fmt.Errorf("reading narsize: %w", err) + } + if _, err := d.r.ReadUint64(); err != nil { + return nil, fmt.Errorf("reading ultimate: %w", err) + } + sigs, err := d.r.ReadStrings() + if err != nil { + return nil, fmt.Errorf("reading sigs: %w", err) + } + if _, err := d.r.ReadUint64(); err != nil { + return nil, fmt.Errorf("reading CA: %w", err) + } + + ni.Sig = make(map[string][]byte) + for _, sigStr := range sigs { + sigBits := strings.Split(sigStr, ":") + sigHash, err := base64.StdEncoding.DecodeString(sigBits[1]) + if err != nil { + return nil, fmt.Errorf("decoding signature %q: %w", sigStr, err) + } + ni.Sig[sigBits[0]] = sigHash + } + + return ni, nil +} + +func (d *Daemon) Close() error { + return d.conn.Close() +} + +func (d *Daemon) readFields() ([]any, error) { + fieldsSz, err := d.r.ReadUint64() + if err != nil { + return nil, fmt.Errorf("reading fields count: %w", err) + } + var fields []any + for n := 0; n < int(fieldsSz); n++ { + fieldTyp, err := d.r.ReadUint64() + if err != nil { + return nil, fmt.Errorf("reading field %d type: %w", n, err) + } + switch fieldTyp { + case 0: // tInt + fieldVal, err := d.r.ReadUint64() + if err != nil { + return nil, fmt.Errorf("reading field %d int: %w", n, err) + } + fields = append(fields, fieldVal) + case 1: // tString + fieldVal, err := d.r.ReadString() + if err != nil { + return nil, fmt.Errorf("reading field %d str: %w", n, err) + } + fields = append(fields, fieldVal) + default: + return nil, fmt.Errorf("reading field %d: unsupported field type %x", n, fieldTyp) + } + } + return fields, nil +} + +func (d *Daemon) processStderr() error { + for { + msg, err := d.r.ReadUint64() + if err != nil { + return fmt.Errorf("read stderr msg code: %w", err) + } + switch msg { + case 0x64617416: // STDERR_WRITE + stderr, err := d.r.ReadString() + if err != nil { + return fmt.Errorf("read stderr msg: %w", err) + } + log.Println(stderr) + case 0x64617461: // STDERR_READ + return fmt.Errorf("STDERR_READ requested") + case 0x63787470: // STDERR_ERROR + errStr, err := d.r.ReadString() + if err != nil { + return fmt.Errorf("STDERR_ERROR reading error string: %w", err) + } + status, err := d.r.ReadUint64() + if err != nil { + return fmt.Errorf("STDERR_ERROR reading uint64: %w", err) + } + return fmt.Errorf("error code %d: %v", status, errStr) + case 0x6f6c6d67: // STDERR_NEXT + errStr, err := d.r.ReadString() + if err != nil { + return fmt.Errorf("STDERR_NEXT reading error string: %w", err) + } + return fmt.Errorf("error: %v", errStr) + case 0x53545254: // STDERR_START_ACTIVITY + activity, err := d.r.ReadUint64() + if err != nil { + return fmt.Errorf("STDERR_START_ACTIVITY reading ActivityId: %w", err) + } + lvl, err := d.r.ReadUint64() + if err != nil { + return fmt.Errorf("STDERR_START_ACTIVITY reading level: %w", err) + } + actType, err := d.r.ReadUint64() + if err != nil { + return fmt.Errorf("STDERR_START_ACTIVITY reading activity type: %w", err) + } + s, err := d.r.ReadString() + if err != nil { + return fmt.Errorf("STDERR_START_ACTIVITY reading s: %w", err) + } + fields, err := d.readFields() + if err != nil { + return fmt.Errorf("STDERR_START_ACTIVITY fields: %w", err) + } + parentActivity, err := d.r.ReadUint64() + if err != nil { + return fmt.Errorf("STDERR_START_ACTIVITY reading parent activity: %w", err) + } + log.Printf("START_ACTIVITY activity=%d, lvl=%d, actType=%d, s=%s, fields=%v, parentActivity=%d", activity, lvl, actType, s, fields, parentActivity) + case 0x53544f50: // STDERR_STOP_ACTIVITY + activity, err := d.r.ReadUint64() + if err != nil { + return fmt.Errorf("STDERR_STOP_ACTIVITY reading ActivityId: %w", err) + } + log.Printf("STOP_ACTIVITY activity=%d", activity) + case 0x52534c54: // STDERR_RESULT + activity, err := d.r.ReadUint64() + if err != nil { + return fmt.Errorf("STDERR_RESULT reading ActivityId: %w", err) + } + resultTyp, err := d.r.ReadUint64() + if err != nil { + return fmt.Errorf("STDERR_RESULT reading result: %w", err) + } + fields, err := d.readFields() + if err != nil { + return fmt.Errorf("STDERR_RESULT fields: %w", err) + } + log.Printf("activity RESULT activity=%d resultType=%d fields=%v", activity, resultTyp, fields) + case 0x616c7473: // STDERR_LAST + return nil + } + } +} + +func (d *Daemon) hello() error { + _, err := d.w.WriteUint64(WorkerMagic1) + if err != nil { + return fmt.Errorf("writing magic 1: %w", err) + } + + magic2, err := d.r.ReadUint64() + if err != nil { + return fmt.Errorf("reading magic 2: %w", err) + } + if magic2 != WorkerMagic2 { + return fmt.Errorf("magic 2 mismatch: got %x, wanted %x", magic2, WorkerMagic2) + } + + daemonVersion, err := d.r.ReadUint64() + if err != nil { + return fmt.Errorf("reading daemon version: %w", err) + } + majorVersion := int((daemonVersion >> 8) & 0xff) + minorVersion := int(daemonVersion & 0xff) + if majorVersion != 1 { + return fmt.Errorf("daemon major version mismatch: got %d, want 1", majorVersion) + } + if minorVersion < 21 { + return fmt.Errorf("daemon minor version too old: got %d, want at least 21", minorVersion) + } + + if _, err := d.w.WriteUint64(ProtocolVersion); err != nil { + return fmt.Errorf("writing protocol version: %w", err) + } + + if _, err := d.w.WriteUint64(0); err != nil { + return fmt.Errorf("writing obsolete CPU affinity: %w", err) + } + if _, err := d.w.WriteUint64(0); err != nil { + return fmt.Errorf("writing obsolete reserveSpace: %w", err) + } + + return d.processStderr() +} + +func OpenDaemon(path string) (*Daemon, error) { + conn, err := net.Dial("unix", path) + if err != nil { + return nil, fmt.Errorf("dialing %v: %w", path, err) + } + + d := &Daemon{conn: conn, w: &nixwire.Serializer{conn}, r: &nixwire.Deserializer{conn}} + if err := d.hello(); err != nil { + d.Close() + return nil, fmt.Errorf("sending hello to %v: %w", path, err) + } + + // /nix/var/nix/daemon-socket/socket + return d, nil +} diff --git a/go/nix/nixstore/remotestore_test.go b/go/nix/nixstore/remotestore_test.go new file mode 100644 index 0000000000..da7feb43b3 --- /dev/null +++ b/go/nix/nixstore/remotestore_test.go @@ -0,0 +1,17 @@ +package nixstore + +import "testing" + +func TestRemoteStore(t *testing.T) { + d, err := OpenDaemon(DaemonSock) + if err != nil { + t.Fatalf("OpenDaemon: %v", err) + } + defer d.Close() + + ni, err := d.NARInfo("/nix/store/whdqnq4nqxx8idwiv36rjnfa9rld3xfq-nix-2.3.16") + if err != nil { + t.Fatalf("NARInfo: %v", err) + } + t.Log(ni) +} diff --git a/go/nix/nixstore/nixstore.go b/go/nix/nixstore/sqlitestore.go similarity index 100% rename from go/nix/nixstore/nixstore.go rename to go/nix/nixstore/sqlitestore.go diff --git a/go/nix/nixstore/sqlitestore_test.go b/go/nix/nixstore/sqlitestore_test.go new file mode 100644 index 0000000000..37a10edb15 --- /dev/null +++ b/go/nix/nixstore/sqlitestore_test.go @@ -0,0 +1 @@ +package nixstore diff --git a/go/nix/nixwire/nixwire.go b/go/nix/nixwire/nixwire.go index df77c52650..a0e7eb83c1 100644 --- a/go/nix/nixwire/nixwire.go +++ b/go/nix/nixwire/nixwire.go @@ -2,6 +2,7 @@ package nixwire import ( "encoding/binary" + "fmt" "io" ) @@ -38,3 +39,67 @@ func (w Serializer) WriteString(s string) (int64, error) { nPad, err := w.WritePadding(int64(len(s))) return int64(nSize) + int64(nData) + int64(nPad), err } + +type Deserializer struct { + io.Reader +} + +func (r Deserializer) ReadPadding(n uint64) error { + if n%8 == 0 { + return nil + } + + buf := make([]byte, 8-(int(n)%8)) + if _, err := io.ReadFull(r, buf); err != nil { + return fmt.Errorf("reading padding: %w", err) + } + for _, n := range buf { + if n != 0 { + return fmt.Errorf("padding was not all 0") + } + } + return nil +} + +func (r Deserializer) ReadUint64() (uint64, error) { + buf := make([]byte, 8) + if _, err := io.ReadFull(r, buf); err != nil { + return 0, fmt.Errorf("reading uint64: %w", err) + } + + return binary.LittleEndian.Uint64(buf), nil +} + +func (r Deserializer) ReadString() (string, error) { + strLen, err := r.ReadUint64() + if err != nil { + return "", fmt.Errorf("reading length: %w", err) + } + + strBuf := make([]byte, int(strLen)) + if _, err := io.ReadFull(r, strBuf); err != nil { + return "", fmt.Errorf("reading string buffer: %w", err) + } + + if err := r.ReadPadding(strLen); err != nil { + return "", fmt.Errorf("reading string padding: %w", err) + } + + return string(strBuf), nil +} + +func (r Deserializer) ReadStrings() ([]string, error) { + setLen, err := r.ReadUint64() + if err != nil { + return nil, fmt.Errorf("reading string set length: %w", err) + } + + xs := make([]string, int(setLen)) + for n := uint64(0); n < setLen; n++ { + xs[n], err = r.ReadString() + if err != nil { + return nil, fmt.Errorf("reading string %d of %d: %w", n, setLen, err) + } + } + return xs, nil +} diff --git a/go/nix/nixwire/nixwire_test.go b/go/nix/nixwire/nixwire_test.go new file mode 100644 index 0000000000..b02ce01010 --- /dev/null +++ b/go/nix/nixwire/nixwire_test.go @@ -0,0 +1,31 @@ +package nixwire + +import ( + "bytes" + "testing" +) + +func TestRoundtrip(t *testing.T) { + var buf bytes.Buffer + + r := Deserializer{&buf} + w := Serializer{&buf} + + for n := 0; n < 128; n++ { + str := make([]byte, n) + for i := 0; i < len(str); i++ { + str[i] = 'a' + } + + if _, err := w.WriteString(string(str)); err != nil { + t.Fatalf("WriteString(%d): %v", n, err) + } + got, err := r.ReadString() + if err != nil { + t.Fatalf("ReadString(%d): %v", n, err) + } + if got != string(str) { + t.Errorf("ReadString != WriteString (%q != %q)", got, string(str)) + } + } +}