|  | @@ -6,12 +6,11 @@ package models
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  import (
 |  |  import (
 | 
											
												
													
														|  |  	"bufio"
 |  |  	"bufio"
 | 
											
												
													
														|  | 
 |  | +	"bytes"
 | 
											
												
													
														|  |  	"errors"
 |  |  	"errors"
 | 
											
												
													
														|  |  	"fmt"
 |  |  	"fmt"
 | 
											
												
													
														|  | -	"io"
 |  | 
 | 
											
												
													
														|  |  	"io/ioutil"
 |  |  	"io/ioutil"
 | 
											
												
													
														|  |  	"os"
 |  |  	"os"
 | 
											
												
													
														|  | -	"os/exec"
 |  | 
 | 
											
												
													
														|  |  	"path"
 |  |  	"path"
 | 
											
												
													
														|  |  	"path/filepath"
 |  |  	"path/filepath"
 | 
											
												
													
														|  |  	"strings"
 |  |  	"strings"
 | 
											
										
											
												
													
														|  | @@ -19,7 +18,9 @@ import (
 | 
											
												
													
														|  |  	"time"
 |  |  	"time"
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	"github.com/Unknwon/com"
 |  |  	"github.com/Unknwon/com"
 | 
											
												
													
														|  | 
 |  | +	qlog "github.com/qiniu/log"
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +	"github.com/gogits/gogs/modules/base"
 | 
											
												
													
														|  |  	"github.com/gogits/gogs/modules/log"
 |  |  	"github.com/gogits/gogs/modules/log"
 | 
											
												
													
														|  |  )
 |  |  )
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -30,29 +31,21 @@ const (
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  var (
 |  |  var (
 | 
											
												
													
														|  |  	ErrKeyAlreadyExist = errors.New("Public key already exist")
 |  |  	ErrKeyAlreadyExist = errors.New("Public key already exist")
 | 
											
												
													
														|  | 
 |  | +	ErrKeyNotExist     = errors.New("Public key does not exist")
 | 
											
												
													
														|  |  )
 |  |  )
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  var sshOpLocker = sync.Mutex{}
 |  |  var sshOpLocker = sync.Mutex{}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  var (
 |  |  var (
 | 
											
												
													
														|  | -	sshPath string
 |  | 
 | 
											
												
													
														|  | -	appPath string
 |  | 
 | 
											
												
													
														|  | 
 |  | +	sshPath string // SSH directory.
 | 
											
												
													
														|  | 
 |  | +	appPath string // Execution(binary) path.
 | 
											
												
													
														|  |  )
 |  |  )
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -// exePath returns the executable path.
 |  | 
 | 
											
												
													
														|  | -func exePath() (string, error) {
 |  | 
 | 
											
												
													
														|  | -	file, err := exec.LookPath(os.Args[0])
 |  | 
 | 
											
												
													
														|  | -	if err != nil {
 |  | 
 | 
											
												
													
														|  | -		return "", err
 |  | 
 | 
											
												
													
														|  | -	}
 |  | 
 | 
											
												
													
														|  | -	return filepath.Abs(file)
 |  | 
 | 
											
												
													
														|  | -}
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  |  // homeDir returns the home directory of current user.
 |  |  // homeDir returns the home directory of current user.
 | 
											
												
													
														|  |  func homeDir() string {
 |  |  func homeDir() string {
 | 
											
												
													
														|  |  	home, err := com.HomeDir()
 |  |  	home, err := com.HomeDir()
 | 
											
												
													
														|  |  	if err != nil {
 |  |  	if err != nil {
 | 
											
												
													
														|  | -		return "/"
 |  | 
 | 
											
												
													
														|  | 
 |  | +		qlog.Fatalln(err)
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  	return home
 |  |  	return home
 | 
											
												
													
														|  |  }
 |  |  }
 | 
											
										
											
												
													
														|  | @@ -60,17 +53,14 @@ func homeDir() string {
 | 
											
												
													
														|  |  func init() {
 |  |  func init() {
 | 
											
												
													
														|  |  	var err error
 |  |  	var err error
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -	appPath, err = exePath()
 |  | 
 | 
											
												
													
														|  | -	if err != nil {
 |  | 
 | 
											
												
													
														|  | -		fmt.Printf("publickey.init(fail to get app path): %v\n", err)
 |  | 
 | 
											
												
													
														|  | -		os.Exit(2)
 |  | 
 | 
											
												
													
														|  | 
 |  | +	if appPath, err = base.ExecDir(); err != nil {
 | 
											
												
													
														|  | 
 |  | +		qlog.Fatalf("publickey.init(fail to get app path): %v\n", err)
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	// Determine and create .ssh path.
 |  |  	// Determine and create .ssh path.
 | 
											
												
													
														|  |  	sshPath = filepath.Join(homeDir(), ".ssh")
 |  |  	sshPath = filepath.Join(homeDir(), ".ssh")
 | 
											
												
													
														|  |  	if err = os.MkdirAll(sshPath, os.ModePerm); err != nil {
 |  |  	if err = os.MkdirAll(sshPath, os.ModePerm); err != nil {
 | 
											
												
													
														|  | -		fmt.Printf("publickey.init(fail to create sshPath(%s)): %v\n", sshPath, err)
 |  | 
 | 
											
												
													
														|  | -		os.Exit(2)
 |  | 
 | 
											
												
													
														|  | 
 |  | +		qlog.Fatalf("publickey.init(fail to create sshPath(%s)): %v\n", sshPath, err)
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  }
 |  |  }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -129,8 +119,8 @@ func AddPublicKey(key *PublicKey) (err error) {
 | 
											
												
													
														|  |  	return nil
 |  |  	return nil
 | 
											
												
													
														|  |  }
 |  |  }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +// rewriteAuthorizedKeys finds and deletes corresponding line in authorized_keys file.
 | 
											
												
													
														|  |  func rewriteAuthorizedKeys(key *PublicKey, p, tmpP string) error {
 |  |  func rewriteAuthorizedKeys(key *PublicKey, p, tmpP string) error {
 | 
											
												
													
														|  | -	// Delete SSH key in SSH key file.
 |  | 
 | 
											
												
													
														|  |  	sshOpLocker.Lock()
 |  |  	sshOpLocker.Lock()
 | 
											
												
													
														|  |  	defer sshOpLocker.Unlock()
 |  |  	defer sshOpLocker.Unlock()
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -146,55 +136,48 @@ func rewriteAuthorizedKeys(key *PublicKey, p, tmpP string) error {
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  	defer fw.Close()
 |  |  	defer fw.Close()
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -	buf := bufio.NewReader(fr)
 |  | 
 | 
											
												
													
														|  | -	for {
 |  | 
 | 
											
												
													
														|  | -		line, errRead := buf.ReadString('\n')
 |  | 
 | 
											
												
													
														|  | -		line = strings.TrimSpace(line)
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -		if errRead != nil {
 |  | 
 | 
											
												
													
														|  | -			if errRead != io.EOF {
 |  | 
 | 
											
												
													
														|  | -				return errRead
 |  | 
 | 
											
												
													
														|  | -			}
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -			// Reached end of file, if nothing to read then break,
 |  | 
 | 
											
												
													
														|  | -			// otherwise handle the last line.
 |  | 
 | 
											
												
													
														|  | -			if len(line) == 0 {
 |  | 
 | 
											
												
													
														|  | -				break
 |  | 
 | 
											
												
													
														|  | -			}
 |  | 
 | 
											
												
													
														|  | 
 |  | +	isFound := false
 | 
											
												
													
														|  | 
 |  | +	keyword := []byte(fmt.Sprintf("key-%d", key.Id))
 | 
											
												
													
														|  | 
 |  | +	content := []byte(key.Content)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	snr := bufio.NewScanner(fr)
 | 
											
												
													
														|  | 
 |  | +	for snr.Scan() {
 | 
											
												
													
														|  | 
 |  | +		line := append(bytes.TrimSpace(snr.Bytes()), '\n')
 | 
											
												
													
														|  | 
 |  | +		if len(line) == 0 {
 | 
											
												
													
														|  | 
 |  | +			continue
 | 
											
												
													
														|  |  		}
 |  |  		}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  		// Found the line and copy rest of file.
 |  |  		// Found the line and copy rest of file.
 | 
											
												
													
														|  | -		if strings.Contains(line, fmt.Sprintf("key-%d", key.Id)) && strings.Contains(line, key.Content) {
 |  | 
 | 
											
												
													
														|  | 
 |  | +		if !isFound && bytes.Contains(line, keyword) && bytes.Contains(line, content) {
 | 
											
												
													
														|  | 
 |  | +			isFound = true
 | 
											
												
													
														|  |  			continue
 |  |  			continue
 | 
											
												
													
														|  |  		}
 |  |  		}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |  		// Still finding the line, copy the line that currently read.
 |  |  		// Still finding the line, copy the line that currently read.
 | 
											
												
													
														|  | -		if _, err = fw.WriteString(line + "\n"); err != nil {
 |  | 
 | 
											
												
													
														|  | 
 |  | +		if _, err = fw.Write(line); err != nil {
 | 
											
												
													
														|  |  			return err
 |  |  			return err
 | 
											
												
													
														|  |  		}
 |  |  		}
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -		if errRead == io.EOF {
 |  | 
 | 
											
												
													
														|  | -			break
 |  | 
 | 
											
												
													
														|  | -		}
 |  | 
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |  	return nil
 |  |  	return nil
 | 
											
												
													
														|  |  }
 |  |  }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  // DeletePublicKey deletes SSH key information both in database and authorized_keys file.
 |  |  // DeletePublicKey deletes SSH key information both in database and authorized_keys file.
 | 
											
												
													
														|  | -func DeletePublicKey(key *PublicKey) (err error) {
 |  | 
 | 
											
												
													
														|  | -	// Delete SSH key in database.
 |  | 
 | 
											
												
													
														|  | -	has, err := orm.Id(key.Id).Get(key)
 |  | 
 | 
											
												
													
														|  | 
 |  | +func DeletePublicKey(key *PublicKey) error {
 | 
											
												
													
														|  | 
 |  | +	has, err := orm.Get(key)
 | 
											
												
													
														|  |  	if err != nil {
 |  |  	if err != nil {
 | 
											
												
													
														|  |  		return err
 |  |  		return err
 | 
											
												
													
														|  |  	} else if !has {
 |  |  	} else if !has {
 | 
											
												
													
														|  | -		return errors.New("Public key does not exist")
 |  | 
 | 
											
												
													
														|  | 
 |  | +		return ErrKeyNotExist
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |  	if _, err = orm.Delete(key); err != nil {
 |  |  	if _, err = orm.Delete(key); err != nil {
 | 
											
												
													
														|  |  		return err
 |  |  		return err
 | 
											
												
													
														|  |  	}
 |  |  	}
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	p := filepath.Join(sshPath, "authorized_keys")
 |  |  	p := filepath.Join(sshPath, "authorized_keys")
 | 
											
												
													
														|  |  	tmpP := filepath.Join(sshPath, "authorized_keys.tmp")
 |  |  	tmpP := filepath.Join(sshPath, "authorized_keys.tmp")
 | 
											
												
													
														|  | -	log.Trace("ssh.DeletePublicKey(authorized_keys): %s", p)
 |  | 
 | 
											
												
													
														|  | 
 |  | +	log.Trace("publickey.DeletePublicKey(authorized_keys): %s", p)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  	if err = rewriteAuthorizedKeys(key, p, tmpP); err != nil {
 |  |  	if err = rewriteAuthorizedKeys(key, p, tmpP); err != nil {
 | 
											
												
													
														|  |  		return err
 |  |  		return err
 |