Commit 7a7f1997 authored by Andreas Kempe's avatar Andreas Kempe
Browse files

Add zlib compression

This refactors the code to add zlib compression. A class called
SegmentParser is added to handle segmenting of the TCP stream so that
compressed chunks can be correctly decompressed. The packet format is
really simple:

+--------+--------------+
| LEN: 4 | PAYLOAD: LEN |
+--------+--------------+

LEN is a four byte field describing how long the payload is.
PAYLOAD is LEN bytes of compressed data.

Compressors can be added via the compressor interface.
parent 7a942d6f
socker
libsockserr.so
*.so
......@@ -4,11 +4,14 @@ CFLAGS ?=
.PHONY: clean
.SILENT: clean
socker: *.pony libsockserr.so
socker: *.pony compressors/*.pony libsockserr.so libcompressor.so
ponyc -p $(.CURDIR)
libsockserr.so: socks_error.c
$(CC) -shared -fPIC $(CFLAGS) -o $(.TARGET) $(.ALLSRC)
libcompressor.so: compressors/zlib.c
$(CC) -shared -fPIC -lz $(CFLAGS) -o $(.TARGET) $(.ALLSRC)
clean:
- rm libsockserr.so socker 2>&1
- rm -f libsockserr.so libcompressor.so socker 2>&1
use "compressors"
use "net"
class ClientAcceptor is TCPListenNotify
......@@ -55,12 +56,17 @@ actor Client is ProxyInterface
var _local_buffer: Array[U8]
var _remote_buffer: Array[U8]
let _compressor: Compressor
let _segment_parser: SegmentParser
new create(out: OutStream) =>
_out = out
_local = None
_remote = None
_local_buffer = Array[U8]
_remote_buffer = Array[U8]
_compressor = Zlib(9)
_segment_parser = SegmentParser
fun _get_opposite(conn: TCPConnection tag) : TCPConnection tag? =>
if _local is conn then
......@@ -83,16 +89,53 @@ actor Client is ProxyInterface
send_buffer(conn, _local_buffer)
end
fun send_buffer(conn: TCPConnection tag, buffer: Array[U8]) =>
fun ref send_to_remote(data: Array[U8] val)? =>
let remote = _remote as TCPConnection
try
let d: Array[U8] val = _compressor.compress(data)?
let segment: Array[U8] val = _segment_parser.create_segment(d)
remote.write(segment)
else
_out.print("Error compressing outgoing data.")
end
fun ref send_to_local(data: Array[U8] val)? =>
let local = _local as TCPConnection
var index: USize = 0
var buffer: Array[U8] val =
_segment_parser.parse_segment(data)
repeat
if buffer.size() != 0 then
try
local.write(_compressor.decompress(buffer)?)
else
_out.print("Error decompressing incoming data.")
end
buffer = _segment_parser.parse_segment(recover Array[U8] end)
end
until buffer.size() == 0 end
fun ref send_buffer(conn: TCPConnection tag, buffer: Array[U8]) =>
let bs = buffer.size()
let data : Array[U8] iso = recover Array[U8](bs) end
let data: Array[U8] iso = recover Array[U8](bs) end
var compressed: Array[U8] iso = recover Array[U8] end
for d in buffer.values() do
data.push(d)
end
buffer.truncate(0)
conn.write(consume data)
try
if conn is _remote then
send_to_remote(consume data)?
elseif conn is _local then
send_to_local(consume data)?
end
else
_out.print("Error sending outgoing data.")
end
be kill(caller: TCPConnection tag) =>
try
......@@ -103,7 +146,16 @@ actor Client is ProxyInterface
be relay(from: TCPConnection tag, data: Array[U8 val] val) =>
try
_get_opposite(from)?.write(data)
/*
* It is possible that the opposite connection we're
* sending to is None. In that case the else clause of the
* try will buffer the data for later sending.
*/
if from is _local then
send_to_remote(data)?
elseif from is _remote then
send_to_local(data)?
end
else
/*
* If we couldn't get the opposite connection, that means
......
interface Compressor
fun compress(data: Array[U8] val): Array[U8] iso^?
fun decompress(data: Array[U8] val): Array[U8] iso^?
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <zlib.h>
z_stream* pony_deflateInit(uint8_t* data, uint32_t size, uint32_t level)
{
z_stream *strm = malloc(sizeof(strm));
if (!strm)
return NULL;
strm->zalloc = Z_NULL;
strm->zfree = Z_NULL;
strm->opaque = Z_NULL;
if (deflateInit(strm, level) != Z_OK)
return NULL;
strm->avail_in = size;
strm->next_in = data;
return strm;
}
uint8_t* pony_compress(z_stream *strm, uint32_t *o_size, uint32_t *remaining)
{
int ret;
uint8_t *buf = malloc(*o_size);
if (!buf)
goto clean;
strm->avail_out = *o_size;
strm->next_out = buf;
if ((ret = deflate(strm, Z_FINISH)) == Z_STREAM_ERROR)
{
free(buf);
buf = NULL;
goto clean;
}
*o_size = *o_size - strm->avail_out;
*remaining = strm->avail_in;
if (ret == Z_STREAM_END)
goto clean;
goto out;
clean:
deflateEnd(strm);
free(strm);
out:
return buf;
}
z_stream* pony_inflateInit(uint8_t* data, uint32_t size)
{
z_stream *strm = malloc(sizeof(strm));
if (!strm)
return NULL;
strm->zalloc = Z_NULL;
strm->zfree = Z_NULL;
strm->opaque = Z_NULL;
if (inflateInit(strm) != Z_OK)
return NULL;
strm->avail_in = size;
strm->next_in = data;
return strm;
}
uint8_t* pony_decompress(z_stream *strm, uint32_t *o_size, uint32_t *remaining)
{
int ret;
uint8_t *buf = malloc(*o_size);
if (!buf)
goto clean;
strm->avail_out = *o_size;
strm->next_out = buf;
if ((ret = inflate(strm, Z_NO_FLUSH)) == Z_STREAM_ERROR)
{
free(buf);
buf = NULL;
goto clean;
}
switch (ret) {
case Z_NEED_DICT:
case Z_DATA_ERROR:
case Z_MEM_ERROR:
buf = NULL;
goto clean;
}
*o_size = *o_size - strm->avail_out;
*remaining = strm->avail_in;
if (ret == Z_STREAM_END)
goto clean;
goto out;
clean:
inflateEnd(strm);
free(strm);
out:
return buf;
}
use "lib:compressor"
use @pony_deflateInit[Pointer[ZStream]](data: Pointer[U8] tag,
size: U32, level: U8)
use @pony_compress[Pointer[U8]](z_stream: Pointer[ZStream],
o_size: Pointer[U32],
remaining: Pointer[U32])
use @pony_inflateInit[Pointer[ZStream]](data: Pointer[U8] tag, size: U32)
use @pony_decompress[Pointer[U8]](z_stream: Pointer[ZStream],
o_size: Pointer[U32],
remaining: Pointer[U32])
struct ZStream
class Zlib is Compressor
let _level: U8
let _chunk_size: U32 = 4000000
new create(level: U8) =>
_level = level
fun compress(data: Array[U8] val): Array[U8] iso^? =>
var z_stream: Pointer[ZStream]
var compressed: Array[U8] iso = recover Array[U8] end
z_stream = @pony_deflateInit(data.cpointer(), data.size().u32(), _level)
if z_stream.is_null() then
error
end
var remaining: U32 = 0
repeat
var o_size: U32 = _chunk_size
let d = @pony_compress(z_stream, addressof o_size,
addressof remaining)
if d.is_null() then
error
end
let buffer: Array[U8] = Array[U8].from_cpointer(d,
o_size.usize(),
_chunk_size.usize())
for i in buffer.values() do
compressed.push(i)
end
until remaining == 0 end
consume compressed
fun decompress(data: Array[U8] val): Array[U8] iso^? =>
var z_stream: Pointer[ZStream]
var decompressed: Array[U8] iso = recover Array[U8] end
z_stream = @pony_inflateInit(data.cpointer(), data.size().u32())
if z_stream.is_null() then
error
end
var remaining: U32 = 0
repeat
var o_size: U32 = _chunk_size
let d = @pony_decompress(z_stream, addressof o_size,
addressof remaining)
if d.is_null() then
error
end
let buffer: Array[U8] = Array[U8].from_cpointer(d,
o_size.usize(),
_chunk_size.usize())
for i in buffer.values() do
decompressed.push(i)
end
until remaining == 0 end
consume decompressed
......@@ -88,7 +88,7 @@ actor Proxy is ProxyInterface
res = res + (address and 0xFFFF).string()
res
fun ref _new_connection(msg: SocksRequestMessage) =>
fun ref _new_connection(msg: SocksRequestMessage val) =>
var address: String
address =
......@@ -118,25 +118,26 @@ actor Proxy is ProxyInterface
try
let r = _socks_relay as SocksRelay
match r.parse(data)?
| None => return
| let result: SocksRequestMessage =>
_new_connection(result)
for msg in r.parse(data)?.values() do
match msg
| let req: SocksRequestMessage val =>
_new_connection(req)
end
end
else
_out.print("Failed to parse data from local client")
end
be socks_failed(conn: TCPConnection tag, err: U8) =>
be socks_relay(data: Array[U8 val] val) =>
try
let r = _socks_relay as SocksRelay
r.connection_failed(err)
else
_out.print("Connection failed, but no relay active!")
r.forward_to_local(data)
end
be socks_relay(data: Array[U8 val] val) =>
be socks_failed(conn: TCPConnection tag, err: U8) =>
try
let r = _socks_relay as SocksRelay
r.forward_to_local(data)
r.connection_failed(err)
else
_out.print("Connection failed, but no relay active!")
end
primitive SegmentLength
fun string(): String val => "SegmentLength"
primitive SegmentData
fun string(): String val => "SegmentData"
type ParseState is (SegmentLength | SegmentData)
class SegmentParser
var _segment_length: U32
var _input_buffer: Array[U8] iso
var _segment_buffer: Array[U8] iso
var _parse_state: ParseState
var _field_size: USize
new create() =>
_segment_length = 0
_input_buffer = recover Array[U8] end
_segment_buffer = recover Array[U8] end
_parse_state = SegmentLength
_field_size = 4
fun ref change_state(to: ParseState) =>
_parse_state = to
match to
| let s: SegmentLength =>
_field_size = 4
_segment_length = 0
| let s: SegmentData =>
_field_size = _segment_length.usize()
end
fun ref parse_segment(data: Array[U8] val): Array[U8] iso^ =>
_input_buffer.reserve(_input_buffer.size() + data.size())
for i in data.values() do
_input_buffer.push(i)
end
var segment: Array[U8] iso = recover Array[U8] end
match _parse_state
| let s: SegmentLength =>
try
while (_field_size > 0) do
_segment_length = _segment_length +
(_input_buffer.shift()?.u32() <<
(8 * (_field_size.u32() - 1)))
_field_size = _field_size - 1
end
change_state(SegmentData)
// If we still have data. Continue parsing.
if _input_buffer.size() > 0 then
segment = parse_segment(recover Array[U8] end)
end
end
| let s: SegmentData =>
try
while _field_size > 0 do
_segment_buffer.push(_input_buffer.shift()?)
_field_size = _field_size - 1
end
change_state(SegmentLength)
// Return _segment_buffer and allocate new buffer.
segment = _segment_buffer = recover Array[U8] end
end
end
consume segment
fun create_segment(data: Array[U8] val): Array[U8] val =>
let b: Array[U8] iso = recover Array[U8] end
var i: USize = 0
while i < 4 do
b.push((data.size().u32() >> (8 * (4 - i - 1).u32())).u8())
i = i + 1
end
b.copy_from(data, 0, 4, data.size())
consume b
use "compressors"
use "net"
primitive AuthNone
......@@ -73,12 +74,11 @@ class SocksMethodSelectionMessage
new create(method': SocksAuthMethod) =>
method = method'
fun send(conn: TCPConnection) =>
fun serialise(): Array[U8] val =>
var data: Array[U8] iso = recover Array[U8](2) end
data.push(5)
data.push(method())
conn.write(consume data)
consume data
primitive SocksCmdConnect
fun apply(): U8 => 0x01
......@@ -175,7 +175,7 @@ class SocksRequestMessage
let destination_address: SocksAddress
let destination_port: U16
new create(data: Array[U8] val)? =>
new create(data: Array[U8] val, out: OutStream)? =>
// VER
if data(0)? != version then
error
......@@ -280,7 +280,7 @@ class SocksRequestReplyMessage
bind_address = bind_address'
bind_port = bind_port'
fun send(conn: TCPConnection tag) =>
fun serialise(): Array[U8] val =>
var msg: Array[U8] iso = recover Array[U8](4) end
msg.push(version)
msg.push(reply())
......@@ -308,7 +308,7 @@ class SocksRequestReplyMessage
msg.push((bind_port >> 8).u8())
msg.push(bind_port.u8())
conn.write(consume msg)
consume msg
type SocksMessage is (SocksConnectionRequest |
SocksMethodSelectionMessage |
......@@ -329,29 +329,79 @@ class SocksRelay
var _connection_state: SocksConnectionState = SocksUninitialised
let _compressor: Compressor
let _segment_parser: SegmentParser
new create(out: OutStream, conn: TCPConnection tag) =>
_out = out
_connection = conn
_remote_connection = None
_compressor = Zlib(9)
_segment_parser = SegmentParser
fun ref _handle_segment(data: Array[U8] val): (SocksMessage val | None)? =>
var decompressed: Array[U8] iso = recover Array[U8] end
try
decompressed = _compressor.decompress(data)?
else
_out.print("Failed to decompress incoming data")
error
end
fun ref parse(data: Array[U8] val): (SocksMessage | None)? =>
match _connection_state
| SocksUninitialised =>
let req = SocksConnectionRequest(data)?
// Request coming from client to SOCKS server
let req = SocksConnectionRequest(consume decompressed)?
_authenticate(req)?
None
| SocksAuthenticated =>
SocksRequestMessage(data)?
let msg: SocksRequestMessage val =
recover
SocksRequestMessage(consume decompressed, _out)?
end
msg
| SocksForwarding =>
// Data coming from client to be forwarded to the end
// point.
try
let r_conn = _remote_connection as TCPConnection
r_conn.write(data)
r_conn.write(consume decompressed)
end
None
else
_out.print("SocksRelay: Unknown connection state")
end
fun forward_to_local(data: Array[U8] val) =>
_connection.write(data)
fun ref parse(data: Array[U8] val): Array[SocksMessage val] val? =>
var ret: Array[SocksMessage val] iso = recover Array[SocksMessage val] end
// When the index is zero there is no data left in the buffer
// and we have extracted it all.
var buffer: Array[U8] val = _segment_parser.parse_segment(data)
repeat
if buffer.size() != 0 then
let msg = _handle_segment(buffer)?
match msg
| None => None
| let m: SocksRequestMessage val =>
ret.push(m)
end
end
buffer = _segment_parser.parse_segment(recover Array[U8] end)
until buffer.size() == 0 end
consume ret
fun ref forward_to_local(data: Array[U8] val) =>
try
let d: Array[U8] val = _compressor.compress(data)?
let segment: Array[U8] val = _segment_parser.create_segment(d)
_connection.write(segment)
else
_out.print("Could not forward data to client. Compression failed.")
return
end
fun ref _authenticate(req: SocksConnectionRequest)? =>
match req.method
......@@ -361,10 +411,10 @@ class SocksRelay
// specification.
_connection_state = SocksAuthenticated
_out.print("AuthNone chosen. Sending reply.")
SocksMethodSelectionMessage(m).send(_connection)
forward_to_local(SocksMethodSelectionMessage(m).serialise())
else
_out.print("No supported authentication method acceptable to client")
SocksMethodSelectionMessage(AuthNotFound).send(_connection)
forward_to_local(SocksMethodSelectionMessage(AuthNotFound).serialise())
error
end
......@@ -395,14 +445,14 @@ class SocksRelay
let msg = SocksRequestReplyMessage(SocksSucceeded, addr_type,
addr, local_port)
msg.send(_connection)
forward_to_local(msg.serialise())
fun ref connection_failed(err: U8) =>
_out.print("Connection failed")
try
let msg = SocksRequestReplyMessage(SocksErrors(err)?, SocksAddrIPv4,
U32(0), 0)
msg.send(_connection)
forward_to_local(msg.serialise())
else
_out.print("Could not send error message. Incorrect error
type: " + err.string())
......@@ -413,7 +463,7 @@ class SocksRelay
_out.print("SOCKS command not supported")
let msg = SocksRequestReplyMessage(SocksCommandNotSupported,
SocksAddrIPv4, U32(0), 0)
msg.send(_connection)
forward_to_local(msg.serialise())
kill()
fun ref kill() =>
......