Browse Source

Merge commit 'b0bba5b4b02979331d50b031ce018559c8af4d42' into wip-tls

wip-tls
Jaeha Choi 2 months ago
parent
commit
710366d872
Signed by: Jaeha.Choi GPG Key ID: D6133D1D117FF503
  1. 9
      pkg/cryptography/aes_gcm_chunks.go
  2. 33
      pkg/log/log.go
  3. 16
      pkg/log/log_test.go
  4. 93
      pkg/util/util.go
  5. 108
      pkg/util/util_test.go

9
pkg/cryptography/aes_gcm_chunks.go

@ -16,9 +16,10 @@ import (
)
const (
// ChunkSize should be less than max value of uint32 (4294967295)
// ChunkSize is a size of each file chunks in bytes.
// Should be less than max value of uint32 (4294967295)
// since the util package use unsigned 4 bytes to represent the data size.
ChunkSize = 1.28e+8
ChunkSize = 16777216 // 2^24 bytes, about 16.7 MB
IvSize = 12
// MaxFileSize indicates theoretical limit for the file size. Because chunk number are
@ -97,8 +98,8 @@ func DecryptSetup() (ag *AesGcmChunk, err error) {
log.Error("Error while creating download directory")
return nil, err
}
// Create temporary file for receiving
tmpFile, err := ioutil.TempFile(util.DownloadPath, ".tmp_download_")
// Create file for decrypted data
tmpFile, err := ioutil.TempFile(util.DownloadPath, ".tmp_decrypted_")
if err != nil {
log.Debug(err)
log.Error("Temp file could not be created")

33
pkg/log/log.go

@ -37,6 +37,13 @@ func Debug(msg ...interface{}) {
}
}
// Debugf logs if LoggingMode is set to DEBUG or lower
func Debugf(format string, msg ...interface{}) {
if mode <= DEBUG {
_ = logger.Output(2, "DEBUG:\t"+fmt.Sprintf(format, msg...))
}
}
// Info logs if LoggingMode is set to INFO or lower
func Info(msg ...interface{}) {
if mode <= INFO {
@ -44,6 +51,13 @@ func Info(msg ...interface{}) {
}
}
// Infof logs if LoggingMode is set to INFO or lower
func Infof(format string, msg ...interface{}) {
if mode <= INFO {
_ = logger.Output(2, "INFO:\t"+fmt.Sprintf(format, msg...))
}
}
// Warning logs if LoggingMode is set to WARNING or lower
func Warning(msg ...interface{}) {
if mode <= WARNING {
@ -51,6 +65,13 @@ func Warning(msg ...interface{}) {
}
}
// Warningf logs if LoggingMode is set to WARNING or lower
func Warningf(format string, msg ...interface{}) {
if mode <= WARNING {
_ = logger.Output(2, "WARNING:\t"+fmt.Sprintf(format, msg...))
}
}
// Error logs if LoggingMode is set to ERROR or lower
func Error(msg ...interface{}) {
if mode <= ERROR {
@ -58,7 +79,19 @@ func Error(msg ...interface{}) {
}
}
// Errorf logs if LoggingMode is set to ERROR or lower
func Errorf(format string, msg ...interface{}) {
if mode <= ERROR {
_ = logger.Output(2, "Error:\t"+fmt.Sprintf(format, msg...))
}
}
// Fatal always logs when used
func Fatal(msg ...interface{}) {
_ = logger.Output(2, "FATAL:\t"+fmt.Sprint(msg...))
}
// Fatalf always logs when used
func Fatalf(format string, msg ...interface{}) {
_ = logger.Output(2, "FATAL:\t"+fmt.Sprintf(format, msg...))
}

16
pkg/log/log_test.go

@ -6,7 +6,7 @@ import (
"testing"
)
func TestInit(t *testing.T) {
func init() {
initTesting(os.Stdout, DEBUG)
}
@ -29,6 +29,20 @@ func TestDebug(t *testing.T) {
}
}
func TestDebugf(t *testing.T) {
var buffer bytes.Buffer
initTesting(&buffer, DEBUG)
Debugf(" %s", "test debug")
if "Test: DEBUG:\t test debug\n" != buffer.String() {
t.Error("Output mismatch")
}
buffer.Reset()
Infof("%s1", "test info")
if "Test: INFO:\ttest info1\n" != buffer.String() {
t.Error("Output mismatch")
}
}
func TestInfo(t *testing.T) {
var buffer bytes.Buffer
initTesting(&buffer, INFO)

93
pkg/util/util.go

@ -10,6 +10,7 @@ import (
"net"
"os"
"path/filepath"
"sync"
)
const (
@ -22,6 +23,13 @@ var SizeError = errors.New("size exceeded")
var EmptyFileName = errors.New("empty filename")
var bufPool = sync.Pool{
New: func() interface{} {
b := make([]byte, BufferSize)
return b
},
}
// ReadString reads string from a connection
func ReadString(reader io.Reader) (string, error) {
// Read packet size (string size)
@ -74,6 +82,25 @@ func ReadBytes(reader io.Reader) (b []byte, err error) {
return b, nil
}
// ReadBytesToWriter reads message from reader and write it to writer.
// First four bytes of reader should be uint32 size of the message,
// represent in big endian.
// Common usage for this function is to read from net.Conn, and write to temp file.
func ReadBytesToWriter(reader io.Reader, writer io.Writer) (n int, err error) {
// Read message size
size, err := readSize(reader)
if err != nil {
log.Debug(err)
return 0, err
}
totalReceived, err := readWrite(reader, writer, size)
if err != nil {
return totalReceived, err
}
return totalReceived, err
}
// ReadBinary reads file name and file content from a connection and save it.
func ReadBinary(reader io.Reader) error {
// Read file name
@ -143,7 +170,7 @@ func WriteBytes(writer io.Writer, b []byte) (n int, err error) {
// total bytes sent = file size.
// writer is likely to be net.Conn. File size cannot exceed max value of uint32
// as of now. We can split files or change the data type to uint64 if time allows.
func WriteBinary(writer io.Writer, filePath string) (uint32, error) {
func WriteBinary(writer io.Writer, filePath string) (int, error) {
// Open source file to send
srcFile, err := os.Open(filePath)
if err != nil {
@ -194,7 +221,7 @@ func WriteBinary(writer io.Writer, filePath string) (uint32, error) {
// Write file to writer
writtenSize, err := readWrite(srcFile, writer, srcFileSize)
if writtenSize != srcFileSize || err != nil {
if err != nil || writtenSize != int(srcFileSize) {
log.Debug(err)
log.Error("Error while writing binary file")
return writtenSize, err
@ -282,7 +309,7 @@ func readNBinary(reader io.Reader, n uint32, fileN string) error {
}
}(tmpFile.Name())
if writtenSize, err := readWrite(reader, tmpFile, n); writtenSize != n || err != nil {
if writtenSize, err := readWrite(reader, tmpFile, n); writtenSize != int(n) || err != nil {
log.Debug(err)
log.Error("Error while reading from reader and writing to temp file")
return err
@ -326,55 +353,37 @@ func readNBytes(reader io.Reader, n uint32) ([]byte, error) {
return buffer, err
}
// readNBytes reads up to nth byte
func readNBytesPointer(reader io.Reader, buffer *[]byte) error {
_, err := io.ReadFull(reader, *buffer)
return err
}
// readWrite is a helper function to read exactly size bytes from reader and write it to writer.
// Returns length of bytes written and error, if any. Error = nil only if length of bytes
// written = size.
func readWrite(reader io.Reader, writer io.Writer, size uint32) (uint32, error) {
var totalReceived uint32 = 0
var receivedLen int
var err error
var buffSize uint32
// Determine buffer size
if size < BufferSize {
buffSize = size
} else {
buffSize = BufferSize
}
// Create buffer
buffer := make([]byte, buffSize)
// Repeat downloading until the file is fully received
for totalReceived < size {
// Last portion of the data
if totalReceived+buffSize > size {
buffer, err = io.ReadAll(io.LimitReader(reader, int64(size-totalReceived)))
receivedLen = len(buffer)
// If reader contains less than expected size
if totalReceived+uint32(receivedLen) != size {
log.Error("File not fully received")
return totalReceived + uint32(receivedLen), errors.New("unexpected EOF")
}
} else {
receivedLen, err = io.ReadFull(reader, buffer)
func readWrite(reader io.Reader, writer io.Writer, size uint32) (int, error) {
totalReceived := 0
intSize := int(size)
readSize := BufferSize
buffer := bufPool.Get().([]byte)
for totalReceived < intSize {
if totalReceived+BufferSize > intSize {
readSize = intSize - totalReceived
}
if err != nil {
read, err := io.ReadFull(reader, buffer[:readSize])
if err != nil || read != readSize {
log.Debug(err)
log.Error("Error while receiving bytes")
return totalReceived, err
}
// Write to writer
writtenLen, err := writer.Write(buffer)
if writtenLen != receivedLen || err != nil {
written, err := writer.Write(buffer[:readSize])
totalReceived += written
if err != nil {
log.Debug(err)
log.Error("Error while writing to a file")
return totalReceived + uint32(writtenLen), err
return totalReceived, err
}
totalReceived += uint32(receivedLen)
}
bufPool.Put(buffer)
return totalReceived, nil
}

108
pkg/util/util_test.go

@ -1103,3 +1103,111 @@ func CleanupHelper() {
log.Error("Existing directory not deleted, perhaps it does not exist?")
}
}
func TestReadBytesTemp(t *testing.T) {
var buf bytes.Buffer
var output bytes.Buffer
err := writeSize(&buf, 4100)
if err != nil {
t.Error()
return
}
testByte, err := ioutil.ReadFile("../testdata/test_4096.txt")
if err != nil {
t.Error(err)
}
buf.Write(testByte)
testByte = []byte("test")
buf.Write(testByte)
temp, err := ReadBytesToWriter(&buf, &output)
if err != nil || temp != 4100 {
t.Error(err)
return
}
}
func BenchmarkReadNBytes(b *testing.B) {
var buf bytes.Buffer
//testByte, err := ioutil.ReadFile("../testdata/test_4096.txt")
//if err != nil{
// b.Error(err)
//}
testByte := []byte("test")
for i := 0; i < b.N; i++ {
buf.Write(testByte)
//buf.Write(testByte)
//buf.Write(testByte)
//buf.Write(testByte)
//buf.Write(testByte)
//buf.Write(testByte)
//buf.Write(testByte)
//buf.Write(testByte)
_, err := readNBytes(&buf, 4)
if err != nil {
b.Error(err)
}
}
}
func BenchmarkReadNBytesPointer(b *testing.B) {
var buf bytes.Buffer
buffer := make([]byte, 4)
//testByte, err := ioutil.ReadFile("../testdata/test_4096.txt")
//if err != nil{
// b.Error(err)
//}
testByte := []byte("test")
for i := 0; i < b.N; i++ {
buf.Write(testByte)
//buf.Write(testByte)
//buf.Write(testByte)
//buf.Write(testByte)
//buf.Write(testByte)
//buf.Write(testByte)
//buf.Write(testByte)
//buf.Write(testByte)
err := readNBytesPointer(&buf, &buffer)
if err != nil {
b.Error(err)
}
}
}
func BenchmarkReadWrite(b *testing.B) {
var buf bytes.Buffer
for i := 0; i < b.N; i++ {
testFile, err := os.Open("../testdata/test_4096.txt")
if err != nil {
b.Error(err)
}
if _, err = readWrite(testFile, &buf, 4096); err != nil {
b.Error(err)
}
}
}
func BenchmarkReadBytesTemp(b *testing.B) {
var buf bytes.Buffer
var output bytes.Buffer
for i := 0; i < b.N; i++ {
err := writeSize(&buf, 4100)
if err != nil {
b.Error()
return
}
testByte, err := ioutil.ReadFile("../testdata/test_4096.txt")
if err != nil {
b.Error(err)
}
buf.Write(testByte)
buf.Write([]byte("test"))
temp, err := ReadBytesToWriter(&buf, &output)
if err != nil || temp != 4100 {
b.Error(err)
return
}
}
}

Loading…
Cancel
Save