diff --git a/Makefile b/Makefile index 89d454d..d6091ac 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ IMG ?= lfedge/adam HASH ?= $(shell git show --format=%T -s) -GOVER ?= 1.16.3-alpine3.13 +GOVER ?= 1.20.10-alpine3.18 # check if we should append a dirty tag @@ -68,7 +68,7 @@ $(LOCALBIN): $(LOCALLINK): @if [ "$(OS)" = "$(BUILDOS)" -a "$(ARCH)" = "$(BUILDARCH)" -a ! -L "$@" -a ! -e "$@" ]; then ln -s $(notdir $(LOCALBIN)) $@; fi -build-docker: +build-docker: docker build -t $(IMG) . build-docker-local: build diff --git a/cmd/admin.go b/cmd/admin.go index dcad851..77d4fa3 100644 --- a/cmd/admin.go +++ b/cmd/admin.go @@ -6,10 +6,10 @@ package cmd import ( "crypto/tls" "crypto/x509" - "io/ioutil" "log" "net/http" "net/url" + "os" "path" "time" @@ -69,7 +69,7 @@ func getStreamingClient() *http.Client { func getClientStreamingOption(stream bool) *http.Client { tlsConfig := &tls.Config{} if serverCA != "" { - caCert, err := ioutil.ReadFile(serverCA) + caCert, err := os.ReadFile(serverCA) if err != nil { log.Fatalf("unable to read server CA file at %s: %v", serverCA, err) } diff --git a/cmd/device.go b/cmd/device.go index 2c47fde..8115064 100644 --- a/cmd/device.go +++ b/cmd/device.go @@ -8,7 +8,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "log" "net/http" "os" @@ -44,7 +43,7 @@ var deviceListCmd = &cobra.Command{ if err != nil { log.Fatalf("error reading URL %s: %v", u, err) } - buf, err := ioutil.ReadAll(response.Body) + buf, err := io.ReadAll(response.Body) if err != nil { log.Fatalf("unable to read data from URL %s: %v", u, err) } @@ -66,7 +65,7 @@ var deviceGetCmd = &cobra.Command{ if err != nil { log.Fatalf("error reading URL %s: %v", u, err) } - buf, err := ioutil.ReadAll(response.Body) + buf, err := io.ReadAll(response.Body) if err != nil { log.Fatalf("unable to read data from URL %s: %v", u, err) } @@ -81,7 +80,7 @@ var deviceAddCmd = &cobra.Command{ Short: "add new device", Long: `Add new device and retrieve the UUID`, Run: func(cmd *cobra.Command, args []string) { - b, err := ioutil.ReadFile(certPath) + b, err := os.ReadFile(certPath) switch { case err != nil && os.IsNotExist(err): log.Fatalf("cert file %s does not exist", certPath) @@ -172,7 +171,7 @@ var deviceConfigGetCmd = &cobra.Command{ if err != nil { log.Fatalf("error reading URL %s: %v", u, err) } - buf, err := ioutil.ReadAll(response.Body) + buf, err := io.ReadAll(response.Body) if err != nil { log.Fatalf("unable to read data from URL %s: %v", u, err) } @@ -192,12 +191,12 @@ var deviceConfigSetCmd = &cobra.Command{ err error ) if configPath == "-" { - b, err = ioutil.ReadAll(os.Stdin) + b, err = io.ReadAll(os.Stdin) if err != nil && err != io.EOF { log.Fatalf("Error reading stdin: %v", err) } } else { - b, err = ioutil.ReadFile(configPath) + b, err = os.ReadFile(configPath) switch { case err != nil && os.IsNotExist(err): log.Fatalf("config file %s does not exist", configPath) @@ -219,7 +218,7 @@ var deviceConfigSetCmd = &cobra.Command{ log.Fatalf("error PUT URL %s: %v", u, err) } if res.StatusCode != 200 { - b, _ := ioutil.ReadAll(res.Body) + b, _ := io.ReadAll(res.Body) log.Fatalf("error PUT URL %s: %d %s", u, res.StatusCode, string(b)) } }, diff --git a/cmd/onboard.go b/cmd/onboard.go index 13f8934..1ae667b 100644 --- a/cmd/onboard.go +++ b/cmd/onboard.go @@ -7,7 +7,7 @@ import ( "bytes" "encoding/json" "fmt" - "io/ioutil" + "io" "log" "net/http" "os" @@ -42,7 +42,7 @@ var onboardListCmd = &cobra.Command{ if err != nil { log.Fatalf("error reading URL %s: %v", u, err) } - buf, err := ioutil.ReadAll(response.Body) + buf, err := io.ReadAll(response.Body) if err != nil { log.Fatalf("unable to read data from URL %s: %v", u, err) } @@ -55,7 +55,7 @@ var onboardAddCmd = &cobra.Command{ Short: "add new onboarding certificate", Long: `Add new onboarding certificate, as well as the valid serials. If the certificate already exists, its serials are replaced by the provided list`, Run: func(cmd *cobra.Command, args []string) { - b, err := ioutil.ReadFile(certPath) + b, err := os.ReadFile(certPath) switch { case err != nil && os.IsNotExist(err): log.Fatalf("cert file %s does not exist", certPath) @@ -95,7 +95,7 @@ var onboardGetCmd = &cobra.Command{ if err != nil { log.Fatalf("error reading URL %s: %v", u, err) } - buf, err := ioutil.ReadAll(response.Body) + buf, err := io.ReadAll(response.Body) if err != nil { log.Fatalf("unable to read data from URL %s: %v", u, err) } diff --git a/cmd/server.go b/cmd/server.go index 0df2ad4..c5cfbd2 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -7,7 +7,6 @@ import ( "crypto/tls" "crypto/x509" "fmt" - "io/ioutil" "log" "os" "path" @@ -115,11 +114,19 @@ var serverCmd = &cobra.Command{ if err != nil { log.Fatalf("error loading server cert and key from environment variables: %v", err) } - if err = ioutil.WriteFile(serverCert, []byte(serverENVCert), 0644); err != nil { - log.Fatal(err) - } - if err = ioutil.WriteFile(serverKey, []byte(serverENVKey), 0600); err != nil { - log.Fatal(err) + + // only create new certs and keys if the files do not exist + _, err = os.Stat(serverCert) + serverCertNotExist := os.IsNotExist(err) + _, err = os.Stat(serverKey) + serverKeyNotExist := os.IsNotExist(err) + if serverCertNotExist && serverKeyNotExist { + if err = os.WriteFile(serverCert, []byte(serverENVCert), 0644); err != nil { + log.Fatal(err) + } + if err = os.WriteFile(serverKey, []byte(serverENVKey), 0600); err != nil { + log.Fatal(err) + } } } else { // if we were asked to autoCert, then we do it @@ -145,34 +152,42 @@ var serverCmd = &cobra.Command{ log.Fatalf("error parsing server cert: %v", err) } - err = ioutil.WriteFile(path.Join(configDir, "server"), []byte(ca.Subject.CommonName+":"+port), 0644) + err = os.WriteFile(path.Join(configDir, "server"), []byte(ca.Subject.CommonName+":"+port), 0644) if err != nil { log.Fatalf("error writing to server file: %v", err) } - err = ioutil.WriteFile(path.Join(configDir, "hosts"), []byte(hostIP+" "+ca.Subject.CommonName), 0644) + err = os.WriteFile(path.Join(configDir, "hosts"), []byte(hostIP+" "+ca.Subject.CommonName), 0644) if err != nil { log.Fatalf("error writing hosts file: %v", err) } - rootCert, err := ioutil.ReadFile(serverCert) + rootCert, err := os.ReadFile(serverCert) if err != nil { log.Fatalf("error reading %s file: %v", serverCert, err) } - err = ioutil.WriteFile(path.Join(configDir, "root-certificate.pem"), rootCert, 0644) + err = os.WriteFile(path.Join(configDir, "root-certificate.pem"), rootCert, 0644) if err != nil { log.Fatalf("error writing root-certificate.pem file: %v", err) } + if signingENVCertProvided && signingENVKeyProvided { - _, err = tls.X509KeyPair([]byte(signingENVCert), []byte(signingENVKey)) - if err != nil { - log.Fatalf("error loading signing cert and key from environment variables: %v", err) - } - if err = ioutil.WriteFile(signingCert, []byte(signingENVCert), 0644); err != nil { - log.Fatal(err) - } - if err = ioutil.WriteFile(signingKey, []byte(signingENVKey), 0600); err != nil { - log.Fatal(err) + // only create new certs and keys if the files do not exist + _, err = os.Stat(signingCert) + signingCertNotExist := os.IsNotExist(err) + _, err = os.Stat(signingKey) + signingKeyNotExist := os.IsNotExist(err) + if signingCertNotExist && signingKeyNotExist { + _, err = tls.X509KeyPair([]byte(signingENVCert), []byte(signingENVKey)) + if err != nil { + log.Fatalf("error loading signing cert and key from environment variables: %v", err) + } + if err = os.WriteFile(signingCert, []byte(signingENVCert), 0644); err != nil { + log.Fatal(err) + } + if err = os.WriteFile(signingKey, []byte(signingENVKey), 0600); err != nil { + log.Fatal(err) + } } } else { // if we were asked to autoCert, then we do it @@ -193,16 +208,24 @@ var serverCmd = &cobra.Command{ log.Printf("Will use APIv1: error loading signing cert %s and signing key %s: %v", signingCert, signingKey, err) } } + if encryptENVCertProvided && encryptENVKeyProvided { - _, err = tls.X509KeyPair([]byte(encryptENVCert), []byte(encryptENVKey)) - if err != nil { - log.Fatalf("error loading encrypt cert and key from environment variables: %v", err) - } - if err = ioutil.WriteFile(encryptCert, []byte(encryptENVCert), 0644); err != nil { - log.Fatal(err) - } - if err = ioutil.WriteFile(encryptKey, []byte(encryptENVKey), 0600); err != nil { - log.Fatal(err) + // only create new certs and keys if the files do not exist + _, err = os.Stat(encryptCert) + encryptCertNotExist := os.IsNotExist(err) + _, err = os.Stat(encryptKey) + encryptKeyNotExist := os.IsNotExist(err) + if encryptCertNotExist && encryptKeyNotExist { + _, err = tls.X509KeyPair([]byte(encryptENVCert), []byte(encryptENVKey)) + if err != nil { + log.Fatalf("error loading encrypt cert and key from environment variables: %v", err) + } + if err = os.WriteFile(encryptCert, []byte(encryptENVCert), 0644); err != nil { + log.Fatal(err) + } + if err = os.WriteFile(encryptKey, []byte(encryptENVKey), 0600); err != nil { + log.Fatal(err) + } } } else { // if we were asked to autoCert, then we do it diff --git a/docs/admin.md b/docs/admin.md index 9bab5c9..e3a143e 100644 --- a/docs/admin.md +++ b/docs/admin.md @@ -29,6 +29,7 @@ The following are the admin endpoints: * `PUT /device/{uuid}/options` - update options for one device * `GET /options` - set global options * `PUT /options` - update global options +* `POST /certs` - update signing certificate ## Adam Admin diff --git a/pkg/driver/device_managers_test.go b/pkg/driver/device_managers_test.go index 983f335..52407ad 100644 --- a/pkg/driver/device_managers_test.go +++ b/pkg/driver/device_managers_test.go @@ -1,7 +1,6 @@ package driver_test import ( - "io/ioutil" "os" "path" "testing" @@ -27,7 +26,7 @@ func TestURLs(t *testing.T) { } // create a temporary working dir, because the file driver actually creates the directories - tmpdir, err := ioutil.TempDir("", "adam-driver-test") + tmpdir, err := os.MkdirTemp("", "adam-driver-test") if err != nil { t.Fatalf("could not create temporary directory: %v", err) } diff --git a/pkg/driver/file/device_manager_file.go b/pkg/driver/file/device_manager_file.go index a347609..cf61c69 100644 --- a/pkg/driver/file/device_manager_file.go +++ b/pkg/driver/file/device_manager_file.go @@ -10,7 +10,6 @@ import ( "encoding/pem" "errors" "fmt" - "io/ioutil" "net/url" "os" "path" @@ -99,7 +98,7 @@ func (m *ManagedFile) Write(b []byte) (int, error) { if m.totalSize > m.maxSize { // get all of the files from the directory - fi, err := ioutil.ReadDir(m.dir) + fi, err := os.ReadDir(m.dir) if err != nil { return written, fmt.Errorf("could not read directory %s: %v", m.dir, err) } @@ -111,8 +110,12 @@ func (m *ManagedFile) Write(b []byte) (int, error) { if m.totalSize < m.maxSize { break } - size := f.Size() filename := path.Join(m.dir, f.Name()) + fileInfo, err := os.Stat(filename) + if err != nil { + return written, fmt.Errorf("could not get file info for %s: %v", filename, err) + } + size := fileInfo.Size() if err := os.Remove(filename); err != nil { return written, fmt.Errorf("failed to remove %s: %v", filename, err) } @@ -308,7 +311,7 @@ func (d *DeviceManager) OnboardGet(cn string) (*x509.Certificate, []string, erro return nil, nil, fmt.Errorf("error reading onboard certificate at %s: %v", certPath, err) } serialPath := path.Join(onboardDir, onboardCertSerials) - serial, err := ioutil.ReadFile(serialPath) + serial, err := os.ReadFile(serialPath) if err != nil { return nil, nil, fmt.Errorf("error reading onboard serials at %s: %v", serialPath, err) } @@ -359,7 +362,7 @@ func (d *DeviceManager) OnboardRemove(cn string) error { func (d *DeviceManager) OnboardClear() error { // remove the directory and clear the cache onboardPath := path.Join(d.databasePath, onboardDir) - candidates, err := ioutil.ReadDir(onboardPath) + candidates, err := os.ReadDir(onboardPath) if err != nil { return fmt.Errorf("unable to read onboarding certificates at %s: %v", onboardPath, err) } @@ -390,8 +393,7 @@ func (d *DeviceManager) DeviceCheckCert(cert *x509.Certificate) (*uuid.UUID, err if err != nil { return nil, fmt.Errorf("unable to refresh certs from filesystem: %v", err) } - certStr := string(cert.Raw) - if u, ok := d.deviceCerts[certStr]; ok { + if u, ok := d.deviceCerts[string(cert.Raw)]; ok { return &u, nil } return nil, nil @@ -440,7 +442,7 @@ func (d *DeviceManager) DeviceRemove(u *uuid.UUID) error { func (d *DeviceManager) DeviceClear() error { // remove the directory and clear the cache devicePath := path.Join(d.databasePath, deviceDir) - candidates, err := ioutil.ReadDir(devicePath) + candidates, err := os.ReadDir(devicePath) if err != nil { return fmt.Errorf("unable to read device certificates at %s: %v", devicePath, err) } @@ -491,7 +493,7 @@ func (d *DeviceManager) DeviceGet(u *uuid.UUID) (*x509.Certificate, *x509.Certif return nil, nil, "", fmt.Errorf("error reading onboard certificate at %s: %v", certPath, err) } serialPath := path.Join(devicePath, deviceSerialFilename) - serial, err := ioutil.ReadFile(serialPath) + serial, err := os.ReadFile(serialPath) // we can accept not reading the onboard serial if err != nil && !os.IsNotExist(err) { return nil, nil, "", fmt.Errorf("error reading device serial at %s: %v", serialPath, err) @@ -621,7 +623,7 @@ func (d *DeviceManager) DeviceRegister(unew uuid.UUID, cert, onboard *x509.Certi } if serial != "" { serialPath := path.Join(devicePath, deviceSerialFilename) - err = ioutil.WriteFile(serialPath, []byte(serial), 0644) + err = os.WriteFile(serialPath, []byte(serial), 0644) if err != nil { return fmt.Errorf("error saving device serial to %s: %v", serialPath, err) } @@ -672,7 +674,7 @@ func (d *DeviceManager) OnboardRegister(cert *x509.Certificate, serial []string) } // serials file f = path.Join(onboardPath, onboardCertSerials) - err = ioutil.WriteFile(f, []byte(strings.Join(serial, "\n")), 0644) + err = os.WriteFile(f, []byte(strings.Join(serial, "\n")), 0644) if err != nil { return fmt.Errorf("unable to write onboard serials file %s: %v", f, err) } @@ -801,7 +803,7 @@ func (d *DeviceManager) WriteCerts(u uuid.UUID, b []byte) error { func (d *DeviceManager) GetCerts(u uuid.UUID) ([]byte, error) { // read the config from disk fullAttestPath := path.Join(d.getDevicePath(u), deviceAttestCertsFilename) - b, err := ioutil.ReadFile(fullAttestPath) + b, err := os.ReadFile(fullAttestPath) if err != nil { return nil, fmt.Errorf("could not read certificates from %s: %v", fullAttestPath, err) } @@ -835,7 +837,7 @@ func (d *DeviceManager) WriteStorageKeys(u uuid.UUID, b []byte) error { func (d *DeviceManager) GetStorageKeys(u uuid.UUID) ([]byte, error) { // read storage keys from disk fullStorageKeysPath := path.Join(d.getDevicePath(u), deviceStorageKeysFilename) - b, err := ioutil.ReadFile(fullStorageKeysPath) + b, err := os.ReadFile(fullStorageKeysPath) if err != nil { return nil, fmt.Errorf("could not read storage keys from %s: %v", fullStorageKeysPath, err) } @@ -847,7 +849,7 @@ func (d *DeviceManager) GetStorageKeys(u uuid.UUID) ([]byte, error) { func (d *DeviceManager) GetConfig(u uuid.UUID) ([]byte, error) { // read the config from disk fullConfigPath := path.Join(d.getDevicePath(u), deviceConfigFilename) - b, err := ioutil.ReadFile(fullConfigPath) + b, err := os.ReadFile(fullConfigPath) switch { case err != nil && os.IsNotExist(err): // create the base file if it does not exist @@ -907,7 +909,7 @@ func (d *DeviceManager) refreshCache() error { // scan the onboard path for all files which end in ".pem" and load them onboardPath := path.Join(d.databasePath, onboardDir) - candidates, err := ioutil.ReadDir(onboardPath) + candidates, err := os.ReadDir(onboardPath) if err != nil { return fmt.Errorf("unable to read onboarding certificates at %s: %v", onboardPath, err) } @@ -926,7 +928,7 @@ func (d *DeviceManager) refreshCache() error { } // read the file - b, err := ioutil.ReadFile(f) + b, err := os.ReadFile(f) if err != nil { return fmt.Errorf("unable to read onboard certificate file %s: %v", f, err) } @@ -947,7 +949,7 @@ func (d *DeviceManager) refreshCache() error { if err != nil { continue } - b, err = ioutil.ReadFile(f) + b, err = os.ReadFile(f) if err != nil { return fmt.Errorf("unable to read onboard serial file %s: %v", f, err) } @@ -960,7 +962,7 @@ func (d *DeviceManager) refreshCache() error { // scan the device path for each dir which is the UUID // and in each one, if a cert exists with the appropriate name, load it devicePath := path.Join(d.databasePath, deviceDir) - candidates, err = ioutil.ReadDir(devicePath) + candidates, err = os.ReadDir(devicePath) if err != nil { return fmt.Errorf("unable to read devices at %s: %v", devicePath, err) } @@ -986,7 +988,7 @@ func (d *DeviceManager) refreshCache() error { continue } // read the file - b, err := ioutil.ReadFile(f) + b, err := os.ReadFile(f) if err != nil { return fmt.Errorf("unable to read device certificate file %s: %v", f, err) } @@ -1010,7 +1012,7 @@ func (d *DeviceManager) refreshCache() error { continue } // read the file - b, err = ioutil.ReadFile(f) + b, err = os.ReadFile(f) if err != nil { return fmt.Errorf("unable to read device onboard certificate file %s: %v", f, err) } @@ -1020,10 +1022,6 @@ func (d *DeviceManager) refreshCache() error { if err != nil { return fmt.Errorf("unable to convert data from file %s to device onboard certificate: %v", f, err) } - certStr = string(cert.Raw) - if err != nil { - return fmt.Errorf("unable to convert device uuid from directory name %s: %v", name, err) - } devItem := d.devices[u] devItem.Onboard = cert d.devices[u] = devItem @@ -1035,7 +1033,7 @@ func (d *DeviceManager) refreshCache() error { continue } // read the file - b, err = ioutil.ReadFile(f) + b, err = os.ReadFile(f) if err != nil { return fmt.Errorf("unable to read device serial file %s: %v", f, err) } @@ -1122,8 +1120,7 @@ func (d *DeviceManager) deviceExists(u uuid.UUID) bool { // checkValidOnboardSerial see if a particular certificate+serial combinaton is valid // does **not** check if it has been used func (d *DeviceManager) checkValidOnboardSerial(cert *x509.Certificate, serial string) error { - certStr := string(cert.Raw) - if c, ok := d.onboardCerts[certStr]; ok { + if c, ok := d.onboardCerts[string(cert.Raw)]; ok { // accept the specific serial or the wildcard if _, ok := c[serial]; ok { return nil @@ -1236,7 +1233,7 @@ func (d *DeviceManager) GetDeviceOptions(u uuid.UUID) ([]byte, error) { } // read options from disk fullOptionsPath := path.Join(d.getDevicePath(u), deviceOptionsFilename) - b, err := ioutil.ReadFile(fullOptionsPath) + b, err := os.ReadFile(fullOptionsPath) if err != nil { // if error another than not exists than return if !os.IsNotExist(err) { @@ -1254,9 +1251,9 @@ func (d *DeviceManager) GetDeviceOptions(u uuid.UUID) ([]byte, error) { } func (d *DeviceManager) SetGlobalOptions(b []byte) error { - return ioutil.WriteFile(filepath.Join(d.databasePath, globalOptionsFilename), b, 0666) + return os.WriteFile(filepath.Join(d.databasePath, globalOptionsFilename), b, 0666) } func (d *DeviceManager) GetGlobalOptions() ([]byte, error) { - return ioutil.ReadFile(filepath.Join(d.databasePath, globalOptionsFilename)) + return os.ReadFile(filepath.Join(d.databasePath, globalOptionsFilename)) } diff --git a/pkg/driver/file/device_manager_file_test.go b/pkg/driver/file/device_manager_file_test.go index 434e0c3..414fc8b 100644 --- a/pkg/driver/file/device_manager_file_test.go +++ b/pkg/driver/file/device_manager_file_test.go @@ -7,7 +7,6 @@ import ( "bytes" "crypto/x509" "fmt" - "io/ioutil" "os" "path" "strings" @@ -54,7 +53,7 @@ func TestDeviceManager(t *testing.T) { if err != nil { t.Fatalf("Unable to write certificate: %v", err) } - ioutil.WriteFile(path.Join(onboardPath, onboardCertSerials), []byte(strings.Join(serials, "\n")), 0644) + os.WriteFile(path.Join(onboardPath, onboardCertSerials), []byte(strings.Join(serials, "\n")), 0644) if err != nil { t.Fatalf("Unable to write serials: %v", err) } @@ -107,7 +106,7 @@ func TestDeviceManager(t *testing.T) { timeout := 5 // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -157,7 +156,7 @@ func TestDeviceManager(t *testing.T) { // include the device onboarding cert copyFile(path.Join(onboardPath, onboardCertFilename), path.Join(devicePath, DeviceOnboardFilename)) // write the device serial - ioutil.WriteFile(path.Join(devicePath, deviceSerialFilename), []byte(serial), 0644) + os.WriteFile(path.Join(devicePath, deviceSerialFilename), []byte(serial), 0644) // wait for the timeout time.Sleep(time.Duration(timeout) * time.Millisecond) // force the cache to refresh @@ -208,7 +207,7 @@ func TestDeviceManager(t *testing.T) { } for i, tt := range tests { // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -231,7 +230,7 @@ func TestDeviceManager(t *testing.T) { if err != nil { t.Fatalf("Unable to write certificate: %v", err) } - ioutil.WriteFile(path.Join(onboardPath, onboardCertSerials), []byte(strings.Join(tt.serials, "\n")), 0644) + os.WriteFile(path.Join(onboardPath, onboardCertSerials), []byte(strings.Join(tt.serials, "\n")), 0644) if err != nil { t.Fatalf("Unable to write serials: %v", err) } @@ -242,7 +241,7 @@ func TestDeviceManager(t *testing.T) { t.Errorf("%d: mismatched errors, actual %v expected %v", i, err, tt.err) case err == nil && !common.EqualStringSlice(serial, tt.serials): t.Errorf("%d: mismatched serials, actual '%v', expected '%v'", i, serial, tt.serials) - case err == nil && bytes.Compare(validCert.Raw, cert.Raw) != 0: + case err == nil && !bytes.Equal(validCert.Raw, cert.Raw): t.Errorf("%d: mismatched certs", i) } } @@ -250,7 +249,7 @@ func TestDeviceManager(t *testing.T) { t.Run("TestOnboardList", func(t *testing.T) { // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -283,7 +282,7 @@ func TestDeviceManager(t *testing.T) { for i, tt := range tests { // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -319,7 +318,7 @@ func TestDeviceManager(t *testing.T) { t.Run("TestOnboardClear", func(t *testing.T) { // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -330,10 +329,10 @@ func TestDeviceManager(t *testing.T) { fillOnboard(&dm) - err = dm.OnboardClear() + dm.OnboardClear() // read the dirs onboardPath := path.Join(dm.databasePath, onboardDir) - candidates, err := ioutil.ReadDir(onboardPath) + candidates, err := os.ReadDir(onboardPath) switch { case err != nil: t.Errorf("unexpected error: %v", err) @@ -358,7 +357,7 @@ func TestDeviceManager(t *testing.T) { } for i, tt := range tests { // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -399,7 +398,7 @@ func TestDeviceManager(t *testing.T) { t.Run("TestDeviceClear", func(t *testing.T) { // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -410,10 +409,10 @@ func TestDeviceManager(t *testing.T) { fillDevice(&dm) - err = dm.DeviceClear() + dm.DeviceClear() // read the dirs devicePath := path.Join(dm.databasePath, deviceDir) - candidates, err := ioutil.ReadDir(devicePath) + candidates, err := os.ReadDir(devicePath) switch { case err != nil: t.Errorf("unexpected error: %v", err) @@ -434,7 +433,7 @@ func TestDeviceManager(t *testing.T) { } for i, tt := range tests { // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -465,7 +464,7 @@ func TestDeviceManager(t *testing.T) { switch { case (err != nil && tt.err == nil) || (err == nil && tt.err != nil) || (err != nil && tt.err != nil && !strings.HasPrefix(err.Error(), tt.err.Error())): t.Errorf("%d: mismatched errors, actual %v expected %v", i, err, tt.err) - case err == nil && cert != nil && fileCert != nil && bytes.Compare(fileCert.Raw, cert.Raw) != 0: + case err == nil && cert != nil && fileCert != nil && !bytes.Equal(fileCert.Raw, cert.Raw): t.Errorf("%d: mismatched cert", i) } } @@ -473,7 +472,7 @@ func TestDeviceManager(t *testing.T) { t.Run("TestDeviceList", func(t *testing.T) { // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -508,7 +507,7 @@ func TestDeviceManager(t *testing.T) { for i, tt := range tests { ts := int64(1000) // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -527,7 +526,7 @@ func TestDeviceManager(t *testing.T) { case err == nil && tt.err == nil && tt.validMsg: // check if the correct file exists // only check if errors were nil, and we had a validMsg; nothing to write otherwise - fi, err := ioutil.ReadDir(sectionPath) + fi, err := os.ReadDir(sectionPath) switch { case err != nil: t.Errorf("missing directory: %s", sectionPath) @@ -622,7 +621,7 @@ func TestDeviceManager(t *testing.T) { ) // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -696,7 +695,7 @@ func TestDeviceManager(t *testing.T) { // reset with each test // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -753,7 +752,7 @@ func TestDeviceManager(t *testing.T) { for _, tt := range tests { // reset with each test // make a temporary directory with which to work - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -792,12 +791,12 @@ func TestDeviceManager(t *testing.T) { } func copyFile(src, dest string) error { - input, err := ioutil.ReadFile(src) + input, err := os.ReadFile(src) if err != nil { return err } - err = ioutil.WriteFile(dest, input, 0644) + err = os.WriteFile(dest, input, 0644) if err != nil { return err } @@ -835,7 +834,7 @@ func checkDeviceDirectory(devicePath string, unew uuid.UUID, deviceCert, onboard if _, err := os.Stat(deviceSerialPath); err != nil && os.IsNotExist(err) { return fmt.Errorf("device serials file %s does not exist", deviceSerialPath) } - if b, err = ioutil.ReadFile(deviceSerialPath); err != nil { + if b, err = os.ReadFile(deviceSerialPath); err != nil { return fmt.Errorf("error reading certificate file %s: %v", deviceSerialPath, err) } if string(b) != serial { @@ -858,7 +857,7 @@ func saveOnboardCertAndSerials(onboardDir string, cert *x509.Certificate, serial return fmt.Errorf("error writing onboard certificate: %v", err) } - err = ioutil.WriteFile(path.Join(onboardPath, onboardCertSerials), []byte(strings.Join(serials, "\n")), 0644) + err = os.WriteFile(path.Join(onboardPath, onboardCertSerials), []byte(strings.Join(serials, "\n")), 0644) if err != nil { return fmt.Errorf("error writing onboard serials: %v", err) } diff --git a/pkg/driver/file/dirreader_test.go b/pkg/driver/file/dirreader_test.go index a59d283..744c861 100644 --- a/pkg/driver/file/dirreader_test.go +++ b/pkg/driver/file/dirreader_test.go @@ -3,7 +3,6 @@ package file_test import ( "bytes" "io" - "io/ioutil" "os" "path" "strings" @@ -31,7 +30,7 @@ func TestDirReader(t *testing.T) { } }) t.Run("empty directory", func(t *testing.T) { - dir, err := ioutil.TempDir("", "dirreader_test") + dir, err := os.MkdirTemp("", "dirreader_test") if err != nil { t.Fatalf("failure to create temporary directory: %v", err) } @@ -45,13 +44,13 @@ func TestDirReader(t *testing.T) { } }) t.Run("single small file", func(t *testing.T) { - dir, err := ioutil.TempDir("", "dirreader_test") + dir, err := os.MkdirTemp("", "dirreader_test") if err != nil { t.Fatalf("failure to create temporary directory: %v", err) } // create the file that is smaller than our buffer data := []byte("Really small file") - if err := ioutil.WriteFile(path.Join(dir, "A"), data, 0644); err != nil { + if err := os.WriteFile(path.Join(dir, "A"), data, 0644); err != nil { t.Fatalf("failure to write temporary file: %v", err) } defer os.RemoveAll(dir) @@ -85,12 +84,12 @@ func TestDirReader(t *testing.T) { }) t.Run("single large file", func(t *testing.T) { data := []byte("Really large file with lots of data bigger than our buffer") - dir, err := ioutil.TempDir("", "dirreader_test") + dir, err := os.MkdirTemp("", "dirreader_test") if err != nil { t.Fatalf("failure to create temporary directory: %v", err) } // create the file that is larger than our buffer - if err := ioutil.WriteFile(path.Join(dir, "A"), data, 0644); err != nil { + if err := os.WriteFile(path.Join(dir, "A"), data, 0644); err != nil { t.Fatalf("failure to write temporary file: %v", err) } defer os.RemoveAll(dir) @@ -144,15 +143,15 @@ func TestDirReader(t *testing.T) { data1 := []byte("Really large file with lots of data bigger than our buffer") data2 := []byte("yet another file with lots of gibberish data to put in") - dir, err := ioutil.TempDir("", "dirreader_test") + dir, err := os.MkdirTemp("", "dirreader_test") if err != nil { t.Fatalf("failure to create temporary directory: %v", err) } // create the multiple files that together are larger than our buffer - if err := ioutil.WriteFile(path.Join(dir, "A"), data1, 0644); err != nil { + if err := os.WriteFile(path.Join(dir, "A"), data1, 0644); err != nil { t.Fatalf("failure to write temporary file: %v", err) } - if err := ioutil.WriteFile(path.Join(dir, "B"), data2, 0644); err != nil { + if err := os.WriteFile(path.Join(dir, "B"), data2, 0644); err != nil { t.Fatalf("failure to write temporary file: %v", err) } diff --git a/pkg/server/adminHandler.go b/pkg/server/adminHandler.go index e3f6e7a..1f620a0 100644 --- a/pkg/server/adminHandler.go +++ b/pkg/server/adminHandler.go @@ -8,7 +8,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "log" "net/http" "strconv" @@ -17,6 +16,7 @@ import ( "github.com/gorilla/mux" "github.com/lf-edge/adam/pkg/driver" "github.com/lf-edge/adam/pkg/driver/common" + "github.com/lf-edge/adam/pkg/util" ax "github.com/lf-edge/adam/pkg/x509" "github.com/lf-edge/eve/api/go/config" "github.com/lf-edge/eve/api/go/info" @@ -35,6 +35,8 @@ type adminHandler struct { logChannel chan []byte infoChannel chan []byte requestsChannel chan []byte + signingCertPath *string + signingKeyPath *string } // OnboardCert encoding for sending an onboard cert and serials via json @@ -325,7 +327,7 @@ func (h *adminHandler) deviceConfigSet(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, "bad UUID", http.StatusBadRequest) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("bad body: %v", err), http.StatusBadRequest) return @@ -586,7 +588,7 @@ func (h *adminHandler) deviceOptionsSet(w http.ResponseWriter, r *http.Request) http.Error(w, err.Error(), http.StatusBadRequest) return } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("bad body: %v", err), http.StatusBadRequest) return @@ -626,7 +628,7 @@ func (h *adminHandler) globalOptionsGet(w http.ResponseWriter, _ *http.Request) } func (h *adminHandler) globalOptionsSet(w http.ResponseWriter, r *http.Request) { - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("bad body: %v", err), http.StatusBadRequest) return @@ -646,3 +648,41 @@ func (h *adminHandler) globalOptionsSet(w http.ResponseWriter, r *http.Request) } w.WriteHeader(http.StatusOK) } + +func (h *adminHandler) signingCertSet(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("bad body: %v", err), http.StatusBadRequest) + return + } + if len(body) == 0 { + http.Error(w, "empty body", http.StatusBadRequest) + return + } + + // validate the cert + _, err = ax.ParseCert(body) + if err != nil { + log.Printf("signingCertSet: %s", err) + http.Error(w, fmt.Sprintf("failed to parse signing cert: %s", err), http.StatusInternalServerError) + return + } + + certPath := *h.signingCertPath + // when no signing cert is set, we will create a new file + if certPath == "" { + // generate a random name + certPath = fmt.Sprintf("/tmp/signingCert-%s.pem", common.RandomString(8)) + } + // write the cert to the file atomically (using a temp file) + err = util.WriteRename(certPath, body) + if err != nil { + log.Printf("signingCertSet: %s", err) + http.Error(w, fmt.Sprintf("failed to write signing cert: %s", err), http.StatusInternalServerError) + return + } + // set the signing cert path + *h.signingCertPath = certPath + + w.WriteHeader(http.StatusOK) +} diff --git a/pkg/server/apiHandler.go b/pkg/server/apiHandler.go index 59735bd..f9fe244 100644 --- a/pkg/server/apiHandler.go +++ b/pkg/server/apiHandler.go @@ -5,7 +5,7 @@ package server import ( "encoding/json" - "io/ioutil" + "io" "log" "net/http" "time" @@ -77,7 +77,7 @@ func (h *apiHandler) register(w http.ResponseWriter, r *http.Request) { // - get the serial // - get the device cert onboardCert := getClientCert(r) - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) if err != nil { log.Printf("error reading request body: %v", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) @@ -158,7 +158,7 @@ func (h *apiHandler) info(w http.ResponseWriter, r *http.Request) { if u == nil { return } - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) if err != nil || len(b) == 0 { log.Printf("error reading request body: %v", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) @@ -178,7 +178,7 @@ func (h *apiHandler) metrics(w http.ResponseWriter, r *http.Request) { if u == nil { return } - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) if err != nil { log.Printf("error reading request body: %v", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) @@ -198,7 +198,7 @@ func (h *apiHandler) logs(w http.ResponseWriter, r *http.Request) { if u == nil { return } - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) if err != nil || len(b) == 0 { log.Printf("error reading request body: %v", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) @@ -239,7 +239,7 @@ func (h *apiHandler) appLogs(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) if err != nil || len(b) == 0 { log.Printf("error reading request body: %v", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) @@ -291,7 +291,7 @@ func (h *apiHandler) newAppLogs(w http.ResponseWriter, r *http.Request) { // retrieve the config request func (h *apiHandler) getClientConfigRequest(r *http.Request) (*config.ConfigRequest, error) { - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { log.Printf("Body read failed: %v", err) return nil, err @@ -310,7 +310,7 @@ func (h *apiHandler) flowLog(w http.ResponseWriter, r *http.Request) { if u == nil { return } - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) if err != nil || len(b) == 0 { log.Printf("error reading request body: %v", err) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) diff --git a/pkg/server/apiHandlerv2.go b/pkg/server/apiHandlerv2.go index 2cbeabf..fc759c8 100644 --- a/pkg/server/apiHandlerv2.go +++ b/pkg/server/apiHandlerv2.go @@ -17,7 +17,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "log" "math/big" "net/http" @@ -43,10 +42,10 @@ type apiHandlerv2 struct { logChannel chan []byte infoChannel chan []byte metricChannel chan []byte - signingCertPath string - signingKeyPath string - encryptCertPath string - encryptKeyPath string + signingCertPath *string + signingKeyPath *string + encryptCertPath *string + encryptKeyPath *string } const ( @@ -83,13 +82,13 @@ func (h *apiHandlerv2) recordClient(u *uuid.UUID, r *http.Request) { h.manager.WriteRequest(*u, b) } -//validateAuthContainerAndRecord processes http.Request extracts AuthContainer and do its validation -//against registered devices: +// validateAuthContainerAndRecord processes http.Request extracts AuthContainer and do its validation +// against registered devices: // checks for certs hash in AuthContainer and go through saved certs to check auth state // it verifies Signature of AuthContainer payload against saved cert // returns ProtectedPayload and device uuid func (h *apiHandlerv2) validateAuthContainerAndRecord(w http.ResponseWriter, r *http.Request) ([]byte, *uuid.UUID) { - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) if err != nil || len(b) == 0 { log.Printf("error reading request body: %v", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -130,10 +129,10 @@ func (h *apiHandlerv2) validateAuthContainerAndRecord(w http.ResponseWriter, r * return payload, u } -//getAllCerts process certificates files and return structure with them +// getAllCerts process certificates files and return structure with them func (h *apiHandlerv2) getAllCerts() (map[string]*certs.ZCert, error) { allCerts := make(map[string]*certs.ZCert) - signingCerts, sgErr := getCertChain(h.signingCertPath, certs.ZCertType_CERT_TYPE_CONTROLLER_SIGNING) + signingCerts, sgErr := getCertChain(*h.signingCertPath, certs.ZCertType_CERT_TYPE_CONTROLLER_SIGNING) if sgErr != nil { return nil, fmt.Errorf("error occurred while fetching signing cert chain: %v", sgErr) } @@ -143,7 +142,7 @@ func (h *apiHandlerv2) getAllCerts() (map[string]*certs.ZCert, error) { allCerts[string(cert.CertHash)] = cert } - encryptCerts, egErr := getCertChain(h.encryptCertPath, certs.ZCertType_CERT_TYPE_CONTROLLER_ECDH_EXCHANGE) + encryptCerts, egErr := getCertChain(*h.encryptCertPath, certs.ZCertType_CERT_TYPE_CONTROLLER_ECDH_EXCHANGE) if egErr != nil { return nil, fmt.Errorf("error occurred while fetching encryption cert chain: %v", egErr) } @@ -206,7 +205,7 @@ func getCertChain(certPath string, certType certs.ZCertType) (*common.Zcerts, er return nil, err } - certData, err := ioutil.ReadFile(certPath) + certData, err := os.ReadFile(certPath) if err != nil { return nil, err } @@ -362,7 +361,7 @@ func (h *apiHandlerv2) prepareEnvelope(payload []byte) ([]byte, error) { var senderCertHash []byte var signingCert []byte - zcerts, gErr := getCertChain(h.signingCertPath, certs.ZCertType_CERT_TYPE_CONTROLLER_SIGNING) + zcerts, gErr := getCertChain(*h.signingCertPath, certs.ZCertType_CERT_TYPE_CONTROLLER_SIGNING) if gErr != nil { return nil, gErr } @@ -374,7 +373,7 @@ func (h *apiHandlerv2) prepareEnvelope(payload []byte) ([]byte, error) { } //read private signing key. - signingPrivateKey, rErr := ioutil.ReadFile(h.signingKeyPath) + signingPrivateKey, rErr := os.ReadFile(*h.signingKeyPath) if rErr != nil { return nil, fmt.Errorf("error occurred while reading signing key: %v", rErr) } @@ -399,7 +398,7 @@ func (h *apiHandlerv2) prepareEnvelope(payload []byte) ([]byte, error) { } func (h *apiHandlerv2) processAuthContainer(reader io.Reader) (*auth.AuthContainer, error) { - b, err := ioutil.ReadAll(reader) + b, err := io.ReadAll(reader) if err != nil || len(b) == 0 { return nil, fmt.Errorf("error reading request body: %v", err) } diff --git a/pkg/server/server.go b/pkg/server/server.go index eb0c064..bec3b7c 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -8,8 +8,8 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io" "io/fs" - "io/ioutil" "log" "net/http" "os" @@ -120,10 +120,10 @@ func (s *Server) Start() { logChannel: logChannel, infoChannel: infoChannel, metricChannel: metricChannel, - signingCertPath: s.SigningCertPath, - signingKeyPath: s.SigningKeyPath, - encryptCertPath: s.EncryptCertPath, - encryptKeyPath: s.EncryptKeyPath, + signingCertPath: &s.SigningCertPath, + signingKeyPath: &s.SigningKeyPath, + encryptCertPath: &s.EncryptCertPath, + encryptKeyPath: &s.EncryptKeyPath, } edv2 := router.PathPrefix("/api/v2/edgedevice").Subrouter() @@ -147,9 +147,11 @@ func (s *Server) Start() { // admin endpoint - custom, used to manage adam admin := &adminHandler{ - manager: s.DeviceManager, - logChannel: logChannel, - infoChannel: infoChannel, + manager: s.DeviceManager, + logChannel: logChannel, + infoChannel: infoChannel, + signingCertPath: &s.SigningCertPath, + signingKeyPath: &s.SigningKeyPath, } ad := router.PathPrefix("/admin").Subrouter() @@ -188,6 +190,7 @@ func (s *Server) Start() { ad.HandleFunc("/device/{uuid}/options", admin.deviceOptionsSet).Methods("PUT") ad.HandleFunc("/options", admin.globalOptionsGet).Methods("GET") ad.HandleFunc("/options", admin.globalOptionsSet).Methods("PUT") + ad.HandleFunc("/certs", admin.signingCertSet).Methods("POST") var ( //index []byte @@ -212,7 +215,7 @@ func (s *Server) Start() { return } defer f.Close() - content, err := ioutil.ReadAll(f) + content, err := io.ReadAll(f) if err != nil { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return diff --git a/pkg/util/file.go b/pkg/util/file.go new file mode 100644 index 0000000..15226de --- /dev/null +++ b/pkg/util/file.go @@ -0,0 +1,67 @@ +package util + +import ( + "errors" + "fmt" + "os" + "path/filepath" +) + +func WriteRename(fileName string, b []byte) error { + dirName := filepath.Dir(fileName) + // Do atomic rename to avoid partially written files + tmpfile, err := os.CreateTemp(dirName, "tmp") + if err != nil { + errStr := fmt.Sprintf("WriteRename(%s): %s", + fileName, err) + return errors.New(errStr) + } + defer tmpfile.Close() + defer os.Remove(tmpfile.Name()) + _, err = tmpfile.Write(b) + if err != nil { + errStr := fmt.Sprintf("WriteRename(%s): %s", + fileName, err) + return errors.New(errStr) + } + // Make sure the file is flushed from buffers onto the disk + if err := tmpfile.Sync(); err != nil { + errStr := fmt.Sprintf("WriteRename(%s) failed to sync temp file: %s", + fileName, err) + return errors.New(errStr) + } + + if err := tmpfile.Close(); err != nil { + errStr := fmt.Sprintf("WriteRename(%s): %s", + fileName, err) + return errors.New(errStr) + } + + if err := os.Rename(tmpfile.Name(), fileName); err != nil { + errStr := fmt.Sprintf("writeRename(%s): %s", + fileName, err) + return errors.New(errStr) + } + + return DirSync(filepath.Dir(fileName)) +} + +// DirSync flushes changes made to a directory. +func DirSync(dirName string) error { + f, err := os.OpenFile(dirName, os.O_RDONLY, 0755) + if err != nil { + return err + } + + err = f.Sync() + if err != nil { + f.Close() + return err + } + + // Not a deferred call, because DirSync is a critical + // path. Better safe then sorry, and we better check all the + // errors including one returned by close() + err = f.Close() + return err +} diff --git a/pkg/x509/generate_test.go b/pkg/x509/generate_test.go index 8df4d86..e11b857 100644 --- a/pkg/x509/generate_test.go +++ b/pkg/x509/generate_test.go @@ -5,7 +5,6 @@ package x509_test import ( "fmt" - "io/ioutil" "os" "path" "regexp" @@ -14,6 +13,7 @@ import ( "testing" "crypto/x509" + ax "github.com/lf-edge/adam/pkg/x509" ) @@ -85,7 +85,7 @@ func TestGenerate(t *testing.T) { } func TestGenerateAndWrite(t *testing.T) { - dir, err := ioutil.TempDir("", "adam-test") + dir, err := os.MkdirTemp("", "adam-test") if err != nil { t.Fatal(err) } @@ -122,12 +122,12 @@ func TestGenerateAndWrite(t *testing.T) { keyPath = path.Join(dir, keyFilename) } if tt.certExists { - ioutil.WriteFile(certPath, []byte{1, 2, 3}, 0644) + os.WriteFile(certPath, []byte{1, 2, 3}, 0644) } else { os.Remove(certPath) } if tt.keyExists { - ioutil.WriteFile(keyPath, []byte{1, 2, 3}, 0644) + os.WriteFile(keyPath, []byte{1, 2, 3}, 0644) } else { os.Remove(keyPath) } diff --git a/pkg/x509/io.go b/pkg/x509/io.go index 373fea6..95611eb 100644 --- a/pkg/x509/io.go +++ b/pkg/x509/io.go @@ -8,7 +8,6 @@ import ( "crypto/x509" "encoding/pem" "fmt" - "io/ioutil" "os" ) @@ -22,7 +21,7 @@ func WriteCert(cert []byte, certPath string, force bool) error { return fmt.Errorf("file already exists at certPath %s", certPath) } certPem := PemEncodeCert(cert) - err := ioutil.WriteFile(certPath, certPem, 0644) + err := os.WriteFile(certPath, certPem, 0644) if err != nil { return fmt.Errorf("failed to write certificate to %s: %v", certPath, err) } @@ -40,7 +39,7 @@ func WriteKey(key *rsa.PrivateKey, keyPath string, force bool) error { return fmt.Errorf("file already exists at keyPath %s", keyPath) } keyPem := PemEncodeKey(key) - err := ioutil.WriteFile(keyPath, keyPem, 0600) + err := os.WriteFile(keyPath, keyPem, 0600) if err != nil { return fmt.Errorf("failed to write key to %s: %v", keyPath, err) } @@ -57,7 +56,7 @@ func ReadCert(p string) (*x509.Certificate, error) { if _, err = os.Stat(p); err != nil && os.IsNotExist(err) { return nil, err } - if b, err = ioutil.ReadFile(p); err != nil { + if b, err = os.ReadFile(p); err != nil { return nil, fmt.Errorf("error reading certificate file %s: %v", p, err) } return ParseCert(b)