diff --git a/admin.go b/admin.go index e78adeec..24c58323 100644 --- a/admin.go +++ b/admin.go @@ -313,7 +313,7 @@ func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []*url.URL { } if admin.Origins == nil { if addr.isLoopback() { - if addr.IsUnixNetwork() { + if addr.IsUnixNetwork() || addr.IsFdNetwork() { // RFC 2616, Section 14.26: // "A client MUST include a Host header field in all HTTP/1.1 request // messages. If the requested URI does not include an Internet host @@ -351,7 +351,7 @@ func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []*url.URL { uniqueOrigins[net.JoinHostPort("127.0.0.1", addr.port())] = struct{}{} } } - if !addr.IsUnixNetwork() { + if !addr.IsUnixNetwork() && !addr.IsFdNetwork() { uniqueOrigins[addr.JoinHostPort(0)] = struct{}{} } } diff --git a/caddyconfig/httpcaddyfile/addresses.go b/caddyconfig/httpcaddyfile/addresses.go index da51fe9b..1c331ead 100644 --- a/caddyconfig/httpcaddyfile/addresses.go +++ b/caddyconfig/httpcaddyfile/addresses.go @@ -77,10 +77,15 @@ import ( // repetition may be undesirable, so call consolidateAddrMappings() to map // multiple addresses to the same lists of server blocks (a many:many mapping). // (Doing this is essentially a map-reduce technique.) -func (st *ServerType) mapAddressToServerBlocks(originalServerBlocks []serverBlock, +func (st *ServerType) mapAddressToProtocolToServerBlocks(originalServerBlocks []serverBlock, options map[string]any, -) (map[string][]serverBlock, error) { - sbmap := make(map[string][]serverBlock) +) (map[string]map[string][]serverBlock, error) { + addrToProtocolToServerBlocks := map[string]map[string][]serverBlock{} + + type keyWithParsedKey struct { + key caddyfile.Token + parsedKey Address + } for i, sblock := range originalServerBlocks { // within a server block, we need to map all the listener addresses @@ -88,27 +93,48 @@ func (st *ServerType) mapAddressToServerBlocks(originalServerBlocks []serverBloc // will be served by them; this has the effect of treating each // key of a server block as its own, but without having to repeat its // contents in cases where multiple keys really can be served together - addrToKeys := make(map[string][]caddyfile.Token) + addrToProtocolToKeyWithParsedKeys := map[string]map[string][]keyWithParsedKey{} for j, key := range sblock.block.Keys { + parsedKey, err := ParseAddress(key.Text) + if err != nil { + return nil, fmt.Errorf("parsing key: %v", err) + } + parsedKey = parsedKey.Normalize() + // a key can have multiple listener addresses if there are multiple // arguments to the 'bind' directive (although they will all have // the same port, since the port is defined by the key or is implicit // through automatic HTTPS) - addrs, err := st.listenerAddrsForServerBlockKey(sblock, key.Text, options) + listeners, err := st.listenersForServerBlockAddress(sblock, parsedKey, options) if err != nil { return nil, fmt.Errorf("server block %d, key %d (%s): determining listener address: %v", i, j, key.Text, err) } - // associate this key with each listener address it is served on - for _, addr := range addrs { - addrToKeys[addr] = append(addrToKeys[addr], key) + // associate this key with its protocols and each listener address served with them + kwpk := keyWithParsedKey{key, parsedKey} + for addr, protocols := range listeners { + protocolToKeyWithParsedKeys, ok := addrToProtocolToKeyWithParsedKeys[addr] + if !ok { + protocolToKeyWithParsedKeys = map[string][]keyWithParsedKey{} + addrToProtocolToKeyWithParsedKeys[addr] = protocolToKeyWithParsedKeys + } + + // an empty protocol indicates the default, a nil or empty value in the ListenProtocols array + if len(protocols) == 0 { + protocols[""] = struct{}{} + } + for prot := range protocols { + protocolToKeyWithParsedKeys[prot] = append( + protocolToKeyWithParsedKeys[prot], + kwpk) + } } } // make a slice of the map keys so we can iterate in sorted order - addrs := make([]string, 0, len(addrToKeys)) - for k := range addrToKeys { - addrs = append(addrs, k) + addrs := make([]string, 0, len(addrToProtocolToKeyWithParsedKeys)) + for addr := range addrToProtocolToKeyWithParsedKeys { + addrs = append(addrs, addr) } sort.Strings(addrs) @@ -118,85 +144,132 @@ func (st *ServerType) mapAddressToServerBlocks(originalServerBlocks []serverBloc // server block are only the ones which use the address; but // the contents (tokens) are of course the same for _, addr := range addrs { - keys := addrToKeys[addr] - // parse keys so that we only have to do it once - parsedKeys := make([]Address, 0, len(keys)) - for _, key := range keys { - addr, err := ParseAddress(key.Text) - if err != nil { - return nil, fmt.Errorf("parsing key '%s': %v", key.Text, err) - } - parsedKeys = append(parsedKeys, addr.Normalize()) + protocolToKeyWithParsedKeys := addrToProtocolToKeyWithParsedKeys[addr] + + prots := make([]string, 0, len(protocolToKeyWithParsedKeys)) + for prot := range protocolToKeyWithParsedKeys { + prots = append(prots, prot) } - sbmap[addr] = append(sbmap[addr], serverBlock{ - block: caddyfile.ServerBlock{ - Keys: keys, - Segments: sblock.block.Segments, - }, - pile: sblock.pile, - keys: parsedKeys, + sort.Strings(prots) + + protocolToServerBlocks, ok := addrToProtocolToServerBlocks[addr] + if !ok { + protocolToServerBlocks = map[string][]serverBlock{} + addrToProtocolToServerBlocks[addr] = protocolToServerBlocks + } + + for _, prot := range prots { + keyWithParsedKeys := protocolToKeyWithParsedKeys[prot] + + keys := make([]caddyfile.Token, len(keyWithParsedKeys)) + parsedKeys := make([]Address, len(keyWithParsedKeys)) + + for k, keyWithParsedKey := range keyWithParsedKeys { + keys[k] = keyWithParsedKey.key + parsedKeys[k] = keyWithParsedKey.parsedKey + } + + protocolToServerBlocks[prot] = append(protocolToServerBlocks[prot], serverBlock{ + block: caddyfile.ServerBlock{ + Keys: keys, + Segments: sblock.block.Segments, + }, + pile: sblock.pile, + parsedKeys: parsedKeys, + }) + } + } + } + + return addrToProtocolToServerBlocks, nil +} + +// consolidateAddrMappings eliminates repetition of identical server blocks in a mapping of +// single listener addresses to protocols to lists of server blocks. Since multiple addresses +// may serve multiple protocols to identical sites (server block contents), this function turns +// a 1:many mapping into a many:many mapping. Server block contents (tokens) must be +// exactly identical so that reflect.DeepEqual returns true in order for the addresses to be combined. +// Identical entries are deleted from the addrToServerBlocks map. Essentially, each pairing (each +// association from multiple addresses to multiple server blocks; i.e. each element of +// the returned slice) becomes a server definition in the output JSON. +func (st *ServerType) consolidateAddrMappings(addrToProtocolToServerBlocks map[string]map[string][]serverBlock) []sbAddrAssociation { + sbaddrs := make([]sbAddrAssociation, 0, len(addrToProtocolToServerBlocks)) + + addrs := make([]string, 0, len(addrToProtocolToServerBlocks)) + for addr := range addrToProtocolToServerBlocks { + addrs = append(addrs, addr) + } + sort.Strings(addrs) + + for _, addr := range addrs { + protocolToServerBlocks := addrToProtocolToServerBlocks[addr] + + prots := make([]string, 0, len(protocolToServerBlocks)) + for prot := range protocolToServerBlocks { + prots = append(prots, prot) + } + sort.Strings(prots) + + for _, prot := range prots { + serverBlocks := protocolToServerBlocks[prot] + + // now find other addresses that map to identical + // server blocks and add them to our map of listener + // addresses and protocols, while removing them from + // the original map + listeners := map[string]map[string]struct{}{} + + for otherAddr, otherProtocolToServerBlocks := range addrToProtocolToServerBlocks { + for otherProt, otherServerBlocks := range otherProtocolToServerBlocks { + if addr == otherAddr && prot == otherProt || reflect.DeepEqual(serverBlocks, otherServerBlocks) { + listener, ok := listeners[otherAddr] + if !ok { + listener = map[string]struct{}{} + listeners[otherAddr] = listener + } + listener[otherProt] = struct{}{} + delete(otherProtocolToServerBlocks, otherProt) + } + } + } + + addresses := make([]string, 0, len(listeners)) + for lnAddr := range listeners { + addresses = append(addresses, lnAddr) + } + sort.Strings(addresses) + + addressesWithProtocols := make([]addressWithProtocols, 0, len(listeners)) + + for _, lnAddr := range addresses { + lnProts := listeners[lnAddr] + prots := make([]string, 0, len(lnProts)) + for prot := range lnProts { + prots = append(prots, prot) + } + sort.Strings(prots) + + addressesWithProtocols = append(addressesWithProtocols, addressWithProtocols{ + address: lnAddr, + protocols: prots, + }) + } + + sbaddrs = append(sbaddrs, sbAddrAssociation{ + addressesWithProtocols: addressesWithProtocols, + serverBlocks: serverBlocks, }) } } - return sbmap, nil -} - -// consolidateAddrMappings eliminates repetition of identical server blocks in a mapping of -// single listener addresses to lists of server blocks. Since multiple addresses may serve -// identical sites (server block contents), this function turns a 1:many mapping into a -// many:many mapping. Server block contents (tokens) must be exactly identical so that -// reflect.DeepEqual returns true in order for the addresses to be combined. Identical -// entries are deleted from the addrToServerBlocks map. Essentially, each pairing (each -// association from multiple addresses to multiple server blocks; i.e. each element of -// the returned slice) becomes a server definition in the output JSON. -func (st *ServerType) consolidateAddrMappings(addrToServerBlocks map[string][]serverBlock) []sbAddrAssociation { - sbaddrs := make([]sbAddrAssociation, 0, len(addrToServerBlocks)) - for addr, sblocks := range addrToServerBlocks { - // we start with knowing that at least this address - // maps to these server blocks - a := sbAddrAssociation{ - addresses: []string{addr}, - serverBlocks: sblocks, - } - - // now find other addresses that map to identical - // server blocks and add them to our list of - // addresses, while removing them from the map - for otherAddr, otherSblocks := range addrToServerBlocks { - if addr == otherAddr { - continue - } - if reflect.DeepEqual(sblocks, otherSblocks) { - a.addresses = append(a.addresses, otherAddr) - delete(addrToServerBlocks, otherAddr) - } - } - sort.Strings(a.addresses) - - sbaddrs = append(sbaddrs, a) - } - - // sort them by their first address (we know there will always be at least one) - // to avoid problems with non-deterministic ordering (makes tests flaky) - sort.Slice(sbaddrs, func(i, j int) bool { - return sbaddrs[i].addresses[0] < sbaddrs[j].addresses[0] - }) - return sbaddrs } -// listenerAddrsForServerBlockKey essentially converts the Caddyfile -// site addresses to Caddy listener addresses for each server block. -func (st *ServerType) listenerAddrsForServerBlockKey(sblock serverBlock, key string, +// listenersForServerBlockAddress essentially converts the Caddyfile site addresses to a map from +// Caddy listener addresses and the protocols to serve them with to the parsed address for each server block. +func (st *ServerType) listenersForServerBlockAddress(sblock serverBlock, addr Address, options map[string]any, -) ([]string, error) { - addr, err := ParseAddress(key) - if err != nil { - return nil, fmt.Errorf("parsing key: %v", err) - } - addr = addr.Normalize() - +) (map[string]map[string]struct{}, error) { switch addr.Scheme { case "wss": return nil, fmt.Errorf("the scheme wss:// is only supported in browsers; use https:// instead") @@ -230,55 +303,54 @@ func (st *ServerType) listenerAddrsForServerBlockKey(sblock serverBlock, key str // error if scheme and port combination violate convention if (addr.Scheme == "http" && lnPort == httpsPort) || (addr.Scheme == "https" && lnPort == httpPort) { - return nil, fmt.Errorf("[%s] scheme and port violate convention", key) + return nil, fmt.Errorf("[%s] scheme and port violate convention", addr.String()) } - // the bind directive specifies hosts (and potentially network), but is optional - lnHosts := make([]string, 0, len(sblock.pile["bind"])) + // the bind directive specifies hosts (and potentially network), and the protocols to serve them with, but is optional + lnCfgVals := make([]addressesWithProtocols, 0, len(sblock.pile["bind"])) for _, cfgVal := range sblock.pile["bind"] { - lnHosts = append(lnHosts, cfgVal.Value.([]string)...) + if val, ok := cfgVal.Value.(addressesWithProtocols); ok { + lnCfgVals = append(lnCfgVals, val) + } } - if len(lnHosts) == 0 { - if defaultBind, ok := options["default_bind"].([]string); ok { - lnHosts = defaultBind + if len(lnCfgVals) == 0 { + if defaultBindValues, ok := options["default_bind"].([]ConfigValue); ok { + for _, defaultBindValue := range defaultBindValues { + lnCfgVals = append(lnCfgVals, defaultBindValue.Value.(addressesWithProtocols)) + } } else { - lnHosts = []string{""} + lnCfgVals = []addressesWithProtocols{{ + addresses: []string{""}, + protocols: nil, + }} } } // use a map to prevent duplication - listeners := make(map[string]struct{}) - for _, lnHost := range lnHosts { - // normally we would simply append the port, - // but if lnHost is IPv6, we need to ensure it - // is enclosed in [ ]; net.JoinHostPort does - // this for us, but lnHost might also have a - // network type in front (e.g. "tcp/") leading - // to "[tcp/::1]" which causes parsing failures - // later; what we need is "tcp/[::1]", so we have - // to split the network and host, then re-combine - network, host, ok := strings.Cut(lnHost, "/") - if !ok { - host = network - network = "" + listeners := map[string]map[string]struct{}{} + for _, lnCfgVal := range lnCfgVals { + for _, lnHost := range lnCfgVal.addresses { + networkAddr, err := caddy.ParseNetworkAddressFromHostPort(lnHost, lnPort) + if err != nil { + return nil, fmt.Errorf("parsing network address: %v", err) + } + if _, ok := listeners[addr.String()]; !ok { + listeners[networkAddr.String()] = map[string]struct{}{} + } + for _, protocol := range lnCfgVal.protocols { + listeners[networkAddr.String()][protocol] = struct{}{} + } } - host = strings.Trim(host, "[]") // IPv6 - networkAddr := caddy.JoinNetworkAddress(network, host, lnPort) - addr, err := caddy.ParseNetworkAddress(networkAddr) - if err != nil { - return nil, fmt.Errorf("parsing network address: %v", err) - } - listeners[addr.String()] = struct{}{} } - // now turn map into list - listenersList := make([]string, 0, len(listeners)) - for lnStr := range listeners { - listenersList = append(listenersList, lnStr) - } - sort.Strings(listenersList) + return listeners, nil +} - return listenersList, nil +// addressesWithProtocols associates a list of listen addresses +// with a list of protocols to serve them with +type addressesWithProtocols struct { + addresses []string + protocols []string } // Address represents a site address. It contains diff --git a/caddyconfig/httpcaddyfile/builtins.go b/caddyconfig/httpcaddyfile/builtins.go index e1e95e00..165c66b2 100644 --- a/caddyconfig/httpcaddyfile/builtins.go +++ b/caddyconfig/httpcaddyfile/builtins.go @@ -56,10 +56,30 @@ func init() { // parseBind parses the bind directive. Syntax: // -// bind +// bind [{ +// protocols [h1|h2|h2c|h3] [...] +// }] func parseBind(h Helper) ([]ConfigValue, error) { h.Next() // consume directive name - return []ConfigValue{{Class: "bind", Value: h.RemainingArgs()}}, nil + var addresses, protocols []string + addresses = h.RemainingArgs() + + for h.NextBlock(0) { + switch h.Val() { + case "protocols": + protocols = h.RemainingArgs() + if len(protocols) == 0 { + return nil, h.Errf("protocols requires one or more arguments") + } + default: + return nil, h.Errf("unknown subdirective: %s", h.Val()) + } + } + + return []ConfigValue{{Class: "bind", Value: addressesWithProtocols{ + addresses: addresses, + protocols: protocols, + }}}, nil } // parseTLS parses the tls directive. Syntax: diff --git a/caddyconfig/httpcaddyfile/directives.go b/caddyconfig/httpcaddyfile/directives.go index 19ef4bc0..f0687a7e 100644 --- a/caddyconfig/httpcaddyfile/directives.go +++ b/caddyconfig/httpcaddyfile/directives.go @@ -516,9 +516,9 @@ func sortRoutes(routes []ConfigValue) { // a "pile" of config values, keyed by class name, // as well as its parsed keys for convenience. type serverBlock struct { - block caddyfile.ServerBlock - pile map[string][]ConfigValue // config values obtained from directives - keys []Address + block caddyfile.ServerBlock + pile map[string][]ConfigValue // config values obtained from directives + parsedKeys []Address } // hostsFromKeys returns a list of all the non-empty hostnames found in @@ -535,7 +535,7 @@ type serverBlock struct { func (sb serverBlock) hostsFromKeys(loggerMode bool) []string { // ensure each entry in our list is unique hostMap := make(map[string]struct{}) - for _, addr := range sb.keys { + for _, addr := range sb.parsedKeys { if addr.Host == "" { if !loggerMode { // server block contains a key like ":443", i.e. the host portion @@ -567,7 +567,7 @@ func (sb serverBlock) hostsFromKeys(loggerMode bool) []string { func (sb serverBlock) hostsFromKeysNotHTTP(httpPort string) []string { // ensure each entry in our list is unique hostMap := make(map[string]struct{}) - for _, addr := range sb.keys { + for _, addr := range sb.parsedKeys { if addr.Host == "" { continue } @@ -588,7 +588,7 @@ func (sb serverBlock) hostsFromKeysNotHTTP(httpPort string) []string { // hasHostCatchAllKey returns true if sb has a key that // omits a host portion, i.e. it "catches all" hosts. func (sb serverBlock) hasHostCatchAllKey() bool { - return slices.ContainsFunc(sb.keys, func(addr Address) bool { + return slices.ContainsFunc(sb.parsedKeys, func(addr Address) bool { return addr.Host == "" }) } @@ -596,7 +596,7 @@ func (sb serverBlock) hasHostCatchAllKey() bool { // isAllHTTP returns true if all sb keys explicitly specify // the http:// scheme func (sb serverBlock) isAllHTTP() bool { - return !slices.ContainsFunc(sb.keys, func(addr Address) bool { + return !slices.ContainsFunc(sb.parsedKeys, func(addr Address) bool { return addr.Scheme != "http" }) } diff --git a/caddyconfig/httpcaddyfile/directives_test.go b/caddyconfig/httpcaddyfile/directives_test.go index db028229..2b4d3e6c 100644 --- a/caddyconfig/httpcaddyfile/directives_test.go +++ b/caddyconfig/httpcaddyfile/directives_test.go @@ -78,7 +78,7 @@ func TestHostsFromKeys(t *testing.T) { []string{"example.com:2015"}, }, } { - sb := serverBlock{keys: tc.keys} + sb := serverBlock{parsedKeys: tc.keys} // test in normal mode actual := sb.hostsFromKeys(false) diff --git a/caddyconfig/httpcaddyfile/httptype.go b/caddyconfig/httpcaddyfile/httptype.go index c858ee56..6745969e 100644 --- a/caddyconfig/httpcaddyfile/httptype.go +++ b/caddyconfig/httpcaddyfile/httptype.go @@ -171,7 +171,7 @@ func (st ServerType) Setup( } // map - sbmap, err := st.mapAddressToServerBlocks(originalServerBlocks, options) + sbmap, err := st.mapAddressToProtocolToServerBlocks(originalServerBlocks, options) if err != nil { return nil, warnings, err } @@ -402,6 +402,20 @@ func (ServerType) evaluateGlobalOptionsBlock(serverBlocks []serverBlock, options options[opt] = append(existingOpts, logOpts...) continue } + // Also fold multiple "default_bind" options together into an + // array so that server blocks can have multiple binds by default. + if opt == "default_bind" { + existingOpts, ok := options[opt].([]ConfigValue) + if !ok { + existingOpts = []ConfigValue{} + } + defaultBindOpts, ok := val.([]ConfigValue) + if !ok { + return nil, fmt.Errorf("unexpected type from 'default_bind' global options: %T", val) + } + options[opt] = append(existingOpts, defaultBindOpts...) + continue + } options[opt] = val } @@ -543,8 +557,40 @@ func (st *ServerType) serversFromPairings( } } + var ( + addresses []string + protocols [][]string + ) + + for _, addressWithProtocols := range p.addressesWithProtocols { + addresses = append(addresses, addressWithProtocols.address) + protocols = append(protocols, addressWithProtocols.protocols) + } + srv := &caddyhttp.Server{ - Listen: p.addresses, + Listen: addresses, + ListenProtocols: protocols, + } + + // remove srv.ListenProtocols[j] if it only contains the default protocols + for j, lnProtocols := range srv.ListenProtocols { + srv.ListenProtocols[j] = nil + for _, lnProtocol := range lnProtocols { + if lnProtocol != "" { + srv.ListenProtocols[j] = lnProtocols + break + } + } + } + + // remove srv.ListenProtocols if it only contains the default protocols for all listen addresses + listenProtocols := srv.ListenProtocols + srv.ListenProtocols = nil + for _, lnProtocols := range listenProtocols { + if lnProtocols != nil { + srv.ListenProtocols = listenProtocols + break + } } // handle the auto_https global option @@ -566,7 +612,7 @@ func (st *ServerType) serversFromPairings( // See ParseAddress() where parsing should later reject paths // See https://github.com/caddyserver/caddy/pull/4728 for a full explanation for _, sblock := range p.serverBlocks { - for _, addr := range sblock.keys { + for _, addr := range sblock.parsedKeys { if addr.Path != "" { caddy.Log().Named("caddyfile").Warn("Using a path in a site address is deprecated; please use the 'handle' directive instead", zap.String("address", addr.String())) } @@ -584,7 +630,7 @@ func (st *ServerType) serversFromPairings( var iLongestPath, jLongestPath string var iLongestHost, jLongestHost string var iWildcardHost, jWildcardHost bool - for _, addr := range p.serverBlocks[i].keys { + for _, addr := range p.serverBlocks[i].parsedKeys { if strings.Contains(addr.Host, "*") || addr.Host == "" { iWildcardHost = true } @@ -595,7 +641,7 @@ func (st *ServerType) serversFromPairings( iLongestPath = addr.Path } } - for _, addr := range p.serverBlocks[j].keys { + for _, addr := range p.serverBlocks[j].parsedKeys { if strings.Contains(addr.Host, "*") || addr.Host == "" { jWildcardHost = true } @@ -711,7 +757,7 @@ func (st *ServerType) serversFromPairings( } } - for _, addr := range sblock.keys { + for _, addr := range sblock.parsedKeys { // if server only uses HTTP port, auto-HTTPS will not apply if listenersUseAnyPortOtherThan(srv.Listen, httpPort) { // exclude any hosts that were defined explicitly with "http://" @@ -886,8 +932,7 @@ func (st *ServerType) serversFromPairings( servers[fmt.Sprintf("srv%d", i)] = srv } - err := applyServerOptions(servers, options, warnings) - if err != nil { + if err := applyServerOptions(servers, options, warnings); err != nil { return nil, fmt.Errorf("applying global server options: %v", err) } @@ -932,7 +977,7 @@ func detectConflictingSchemes(srv *caddyhttp.Server, serverBlocks []serverBlock, } for _, sblock := range serverBlocks { - for _, addr := range sblock.keys { + for _, addr := range sblock.parsedKeys { if addr.Scheme == "http" || addr.Port == httpPort { if err := checkAndSetHTTP(addr); err != nil { return err @@ -1322,7 +1367,7 @@ func (st *ServerType) compileEncodedMatcherSets(sblock serverBlock) ([]caddy.Mod var matcherPairs []*hostPathPair var catchAllHosts bool - for _, addr := range sblock.keys { + for _, addr := range sblock.parsedKeys { // choose a matcher pair that should be shared by this // server block; if none exists yet, create one var chosenMatcherPair *hostPathPair @@ -1594,12 +1639,19 @@ type namedCustomLog struct { noHostname bool } +// addressWithProtocols associates a listen address with +// the protocols to serve it with +type addressWithProtocols struct { + address string + protocols []string +} + // sbAddrAssociation is a mapping from a list of -// addresses to a list of server blocks that are -// served on those addresses. +// addresses with protocols, and a list of server +// blocks that are served on those addresses. type sbAddrAssociation struct { - addresses []string - serverBlocks []serverBlock + addressesWithProtocols []addressWithProtocols + serverBlocks []serverBlock } const ( diff --git a/caddyconfig/httpcaddyfile/options.go b/caddyconfig/httpcaddyfile/options.go index 53687d32..c14208b6 100644 --- a/caddyconfig/httpcaddyfile/options.go +++ b/caddyconfig/httpcaddyfile/options.go @@ -31,7 +31,7 @@ func init() { RegisterGlobalOption("debug", parseOptTrue) RegisterGlobalOption("http_port", parseOptHTTPPort) RegisterGlobalOption("https_port", parseOptHTTPSPort) - RegisterGlobalOption("default_bind", parseOptStringList) + RegisterGlobalOption("default_bind", parseOptDefaultBind) RegisterGlobalOption("grace_period", parseOptDuration) RegisterGlobalOption("shutdown_delay", parseOptDuration) RegisterGlobalOption("default_sni", parseOptSingleString) @@ -284,13 +284,32 @@ func parseOptSingleString(d *caddyfile.Dispenser, _ any) (any, error) { return val, nil } -func parseOptStringList(d *caddyfile.Dispenser, _ any) (any, error) { +func parseOptDefaultBind(d *caddyfile.Dispenser, _ any) (any, error) { d.Next() // consume option name - val := d.RemainingArgs() - if len(val) == 0 { - return "", d.ArgErr() + + var addresses, protocols []string + addresses = d.RemainingArgs() + + if len(addresses) == 0 { + addresses = append(addresses, "") } - return val, nil + + for d.NextBlock(0) { + switch d.Val() { + case "protocols": + protocols = d.RemainingArgs() + if len(protocols) == 0 { + return nil, d.Errf("protocols requires one or more arguments") + } + default: + return nil, d.Errf("unknown subdirective: %s", d.Val()) + } + } + + return []ConfigValue{{Class: "bind", Value: addressesWithProtocols{ + addresses: addresses, + protocols: protocols, + }}}, nil } func parseOptAdmin(d *caddyfile.Dispenser, _ any) (any, error) { diff --git a/caddyconfig/httpcaddyfile/tlsapp.go b/caddyconfig/httpcaddyfile/tlsapp.go index 157a3113..c6ff81b2 100644 --- a/caddyconfig/httpcaddyfile/tlsapp.go +++ b/caddyconfig/httpcaddyfile/tlsapp.go @@ -57,13 +57,13 @@ func (st ServerType) buildTLSApp( if autoHTTPS != "off" { for _, pair := range pairings { for _, sb := range pair.serverBlocks { - for _, addr := range sb.keys { + for _, addr := range sb.parsedKeys { if addr.Host != "" { continue } // this server block has a hostless key, now // go through and add all the hosts to the set - for _, otherAddr := range sb.keys { + for _, otherAddr := range sb.parsedKeys { if otherAddr.Original == addr.Original { continue } @@ -93,7 +93,11 @@ func (st ServerType) buildTLSApp( for _, p := range pairings { // avoid setting up TLS automation policies for a server that is HTTP-only - if !listenersUseAnyPortOtherThan(p.addresses, httpPort) { + var addresses []string + for _, addressWithProtocols := range p.addressesWithProtocols { + addresses = append(addresses, addressWithProtocols.address) + } + if !listenersUseAnyPortOtherThan(addresses, httpPort) { continue } @@ -183,8 +187,8 @@ func (st ServerType) buildTLSApp( if acmeIssuer.Challenges.BindHost == "" { // only binding to one host is supported var bindHost string - if bindHosts, ok := cfgVal.Value.([]string); ok && len(bindHosts) > 0 { - bindHost = bindHosts[0] + if asserted, ok := cfgVal.Value.(addressesWithProtocols); ok && len(asserted.addresses) > 0 { + bindHost = asserted.addresses[0] } acmeIssuer.Challenges.BindHost = bindHost } diff --git a/caddytest/integration/caddyfile_adapt/bind_fd_fdgram_h123.caddyfiletest b/caddytest/integration/caddyfile_adapt/bind_fd_fdgram_h123.caddyfiletest new file mode 100644 index 00000000..08f30d18 --- /dev/null +++ b/caddytest/integration/caddyfile_adapt/bind_fd_fdgram_h123.caddyfiletest @@ -0,0 +1,142 @@ +{ + auto_https disable_redirects + admin off +} + +http://localhost { + bind fd/{env.CADDY_HTTP_FD} { + protocols h1 + } + log + respond "Hello, HTTP!" +} + +https://localhost { + bind fd/{env.CADDY_HTTPS_FD} { + protocols h1 h2 + } + bind fdgram/{env.CADDY_HTTP3_FD} { + protocols h3 + } + log + respond "Hello, HTTPS!" +} +---------- +{ + "admin": { + "disabled": true + }, + "apps": { + "http": { + "servers": { + "srv0": { + "listen": [ + "fd/{env.CADDY_HTTPS_FD}", + "fdgram/{env.CADDY_HTTP3_FD}" + ], + "routes": [ + { + "match": [ + { + "host": [ + "localhost" + ] + } + ], + "handle": [ + { + "handler": "subroute", + "routes": [ + { + "handle": [ + { + "body": "Hello, HTTPS!", + "handler": "static_response" + } + ] + } + ] + } + ], + "terminal": true + } + ], + "automatic_https": { + "disable_redirects": true + }, + "logs": { + "logger_names": { + "localhost": [ + "" + ] + } + }, + "listen_protocols": [ + [ + "h1", + "h2" + ], + [ + "h3" + ] + ] + }, + "srv1": { + "automatic_https": { + "disable_redirects": true + } + }, + "srv2": { + "listen": [ + "fd/{env.CADDY_HTTP_FD}" + ], + "routes": [ + { + "match": [ + { + "host": [ + "localhost" + ] + } + ], + "handle": [ + { + "handler": "subroute", + "routes": [ + { + "handle": [ + { + "body": "Hello, HTTP!", + "handler": "static_response" + } + ] + } + ] + } + ], + "terminal": true + } + ], + "automatic_https": { + "disable_redirects": true, + "skip": [ + "localhost" + ] + }, + "logs": { + "logger_names": { + "localhost": [ + "" + ] + } + }, + "listen_protocols": [ + [ + "h1" + ] + ] + } + } + } + } +} diff --git a/cmd/commandfuncs.go b/cmd/commandfuncs.go index 49d0321e..50e9b110 100644 --- a/cmd/commandfuncs.go +++ b/cmd/commandfuncs.go @@ -660,6 +660,8 @@ func AdminAPIRequest(adminAddr, method, uri string, headers http.Header, body io return nil, err } parsedAddr.Host = addr + } else if parsedAddr.IsFdNetwork() { + origin = "http://127.0.0.1" } // form the request @@ -667,13 +669,13 @@ func AdminAPIRequest(adminAddr, method, uri string, headers http.Header, body io if err != nil { return nil, fmt.Errorf("making request: %v", err) } - if parsedAddr.IsUnixNetwork() { + if parsedAddr.IsUnixNetwork() || parsedAddr.IsFdNetwork() { // We used to conform to RFC 2616 Section 14.26 which requires // an empty host header when there is no host, as is the case - // with unix sockets. However, Go required a Host value so we - // used a hack of a space character as the host (it would see - // the Host was non-empty, then trim the space later). As of - // Go 1.20.6 (July 2023), this hack no longer works. See: + // with unix sockets and socket fds. However, Go required a + // Host value so we used a hack of a space character as the host + // (it would see the Host was non-empty, then trim the space later). + // As of Go 1.20.6 (July 2023), this hack no longer works. See: // https://github.com/golang/go/issues/60374 // See also the discussion here: // https://github.com/golang/go/issues/61431 diff --git a/cmd/main_test.go b/cmd/main_test.go index 757a58ce..3b2412c5 100644 --- a/cmd/main_test.go +++ b/cmd/main_test.go @@ -235,7 +235,7 @@ func Test_isCaddyfile(t *testing.T) { wantErr: false, }, { - + name: "json is not caddyfile but not error", args: args{ configFile: "./Caddyfile.json", @@ -245,7 +245,7 @@ func Test_isCaddyfile(t *testing.T) { wantErr: false, }, { - + name: "prefix of Caddyfile and ./ with any extension is Caddyfile", args: args{ configFile: "./Caddyfile.prd", @@ -255,7 +255,7 @@ func Test_isCaddyfile(t *testing.T) { wantErr: false, }, { - + name: "prefix of Caddyfile without ./ with any extension is Caddyfile", args: args{ configFile: "Caddyfile.prd", diff --git a/listen.go b/listen.go index 34812b54..f5c2086a 100644 --- a/listen.go +++ b/listen.go @@ -18,7 +18,11 @@ package caddy import ( "context" + "fmt" "net" + "os" + "slices" + "strconv" "sync" "sync/atomic" "time" @@ -31,10 +35,49 @@ func reuseUnixSocket(network, addr string) (any, error) { } func listenReusable(ctx context.Context, lnKey string, network, address string, config net.ListenConfig) (any, error) { - switch network { - case "udp", "udp4", "udp6", "unixgram": + var socketFile *os.File + + fd := slices.Contains([]string{"fd", "fdgram"}, network) + if fd { + socketFd, err := strconv.ParseUint(address, 0, strconv.IntSize) + if err != nil { + return nil, fmt.Errorf("invalid file descriptor: %v", err) + } + + func() { + socketFilesMu.Lock() + defer socketFilesMu.Unlock() + + socketFdWide := uintptr(socketFd) + var ok bool + + socketFile, ok = socketFiles[socketFdWide] + + if !ok { + socketFile = os.NewFile(socketFdWide, lnKey) + if socketFile != nil { + socketFiles[socketFdWide] = socketFile + } + } + }() + + if socketFile == nil { + return nil, fmt.Errorf("invalid socket file descriptor: %d", socketFd) + } + } + + datagram := slices.Contains([]string{"udp", "udp4", "udp6", "unixgram", "fdgram"}, network) + if datagram { sharedPc, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) { - pc, err := config.ListenPacket(ctx, network, address) + var ( + pc net.PacketConn + err error + ) + if fd { + pc, err = net.FilePacketConn(socketFile) + } else { + pc, err = config.ListenPacket(ctx, network, address) + } if err != nil { return nil, err } @@ -44,20 +87,27 @@ func listenReusable(ctx context.Context, lnKey string, network, address string, return nil, err } return &fakeClosePacketConn{sharedPacketConn: sharedPc.(*sharedPacketConn)}, nil + } - default: - sharedLn, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) { - ln, err := config.Listen(ctx, network, address) - if err != nil { - return nil, err - } - return &sharedListener{Listener: ln, key: lnKey}, nil - }) + sharedLn, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) { + var ( + ln net.Listener + err error + ) + if fd { + ln, err = net.FileListener(socketFile) + } else { + ln, err = config.Listen(ctx, network, address) + } if err != nil { return nil, err } - return &fakeCloseListener{sharedListener: sharedLn.(*sharedListener), keepAlivePeriod: config.KeepAlive}, nil + return &sharedListener{Listener: ln, key: lnKey}, nil + }) + if err != nil { + return nil, err } + return &fakeCloseListener{sharedListener: sharedLn.(*sharedListener), keepAlivePeriod: config.KeepAlive}, nil } // fakeCloseListener is a private wrapper over a listener that @@ -260,3 +310,9 @@ var ( Unwrap() net.PacketConn }) = (*fakeClosePacketConn)(nil) ) + +// socketFiles is a fd -> *os.File map used to make a FileListener/FilePacketConn from a socket file descriptor. +var socketFiles = map[uintptr]*os.File{} + +// socketFilesMu synchronizes socketFiles insertions +var socketFilesMu sync.Mutex diff --git a/listen_unix.go b/listen_unix.go index 9ec65c39..d6ae0cb8 100644 --- a/listen_unix.go +++ b/listen_unix.go @@ -22,10 +22,14 @@ package caddy import ( "context" "errors" + "fmt" "io" "io/fs" "net" "os" + "slices" + "strconv" + "sync" "sync/atomic" "syscall" @@ -34,12 +38,9 @@ import ( ) // reuseUnixSocket copies and reuses the unix domain socket (UDS) if we already -// have it open; if not, unlink it so we can have it. No-op if not a unix network. +// have it open; if not, unlink it so we can have it. +// No-op if not a unix network. func reuseUnixSocket(network, addr string) (any, error) { - if !IsUnixNetwork(network) { - return nil, nil - } - socketKey := listenerKey(network, addr) socket, exists := unixSockets[socketKey] @@ -71,7 +72,7 @@ func reuseUnixSocket(network, addr string) (any, error) { return nil, err } atomic.AddInt32(unixSocket.count, 1) - unixSockets[socketKey] = &unixConn{pc.(*net.UnixConn), addr, socketKey, unixSocket.count} + unixSockets[socketKey] = &unixConn{pc.(*net.UnixConn), socketKey, unixSocket.count} } return unixSockets[socketKey], nil @@ -89,56 +90,107 @@ func reuseUnixSocket(network, addr string) (any, error) { return nil, nil } +// listenReusable creates a new listener for the given network and address, and adds it to listenerPool. func listenReusable(ctx context.Context, lnKey string, network, address string, config net.ListenConfig) (any, error) { - // wrap any Control function set by the user so we can also add our reusePort control without clobbering theirs - oldControl := config.Control - config.Control = func(network, address string, c syscall.RawConn) error { - if oldControl != nil { - if err := oldControl(network, address, c); err != nil { - return err - } - } - return reusePort(network, address, c) - } - // even though SO_REUSEPORT lets us bind the socket multiple times, // we still put it in the listenerPool so we can count how many // configs are using this socket; necessary to ensure we can know // whether to enforce shutdown delays, for example (see #5393). - var ln io.Closer - var err error - switch network { - case "udp", "udp4", "udp6", "unixgram": - ln, err = config.ListenPacket(ctx, network, address) - default: - ln, err = config.Listen(ctx, network, address) + var ( + ln io.Closer + err error + socketFile *os.File + ) + + fd := slices.Contains([]string{"fd", "fdgram"}, network) + if fd { + socketFd, err := strconv.ParseUint(address, 0, strconv.IntSize) + if err != nil { + return nil, fmt.Errorf("invalid file descriptor: %v", err) + } + + func() { + socketFilesMu.Lock() + defer socketFilesMu.Unlock() + + socketFdWide := uintptr(socketFd) + var ok bool + + socketFile, ok = socketFiles[socketFdWide] + + if !ok { + socketFile = os.NewFile(socketFdWide, lnKey) + if socketFile != nil { + socketFiles[socketFdWide] = socketFile + } + } + }() + + if socketFile == nil { + return nil, fmt.Errorf("invalid socket file descriptor: %d", socketFd) + } + } else { + // wrap any Control function set by the user so we can also add our reusePort control without clobbering theirs + oldControl := config.Control + config.Control = func(network, address string, c syscall.RawConn) error { + if oldControl != nil { + if err := oldControl(network, address, c); err != nil { + return err + } + } + return reusePort(network, address, c) + } } + + datagram := slices.Contains([]string{"udp", "udp4", "udp6", "unixgram", "fdgram"}, network) + if datagram { + if fd { + ln, err = net.FilePacketConn(socketFile) + } else { + ln, err = config.ListenPacket(ctx, network, address) + } + } else { + if fd { + ln, err = net.FileListener(socketFile) + } else { + ln, err = config.Listen(ctx, network, address) + } + } + if err == nil { listenerPool.LoadOrStore(lnKey, nil) } - // if new listener is a unix socket, make sure we can reuse it later - // (we do our own "unlink on close" -- not required, but more tidy) - one := int32(1) - if unix, ok := ln.(*net.UnixListener); ok { - unix.SetUnlinkOnClose(false) - ln = &unixListener{unix, lnKey, &one} - unixSockets[lnKey] = ln.(*unixListener) - } - - // TODO: Not 100% sure this is necessary, but we do this for net.UnixListener in listen_unix.go, so... - if unix, ok := ln.(*net.UnixConn); ok { - ln = &unixConn{unix, address, lnKey, &one} - unixSockets[lnKey] = ln.(*unixConn) - } - - // lightly wrap the listener so that when it is closed, - // we can decrement the usage pool counter - switch specificLn := ln.(type) { - case net.Listener: - return deleteListener{specificLn, lnKey}, err - case net.PacketConn: - return deletePacketConn{specificLn, lnKey}, err + if datagram { + if !fd { + // TODO: Not 100% sure this is necessary, but we do this for net.UnixListener, so... + if unix, ok := ln.(*net.UnixConn); ok { + one := int32(1) + ln = &unixConn{unix, lnKey, &one} + unixSockets[lnKey] = ln.(*unixConn) + } + } + // lightly wrap the connection so that when it is closed, + // we can decrement the usage pool counter + if specificLn, ok := ln.(net.PacketConn); ok { + ln = deletePacketConn{specificLn, lnKey} + } + } else { + if !fd { + // if new listener is a unix socket, make sure we can reuse it later + // (we do our own "unlink on close" -- not required, but more tidy) + if unix, ok := ln.(*net.UnixListener); ok { + unix.SetUnlinkOnClose(false) + one := int32(1) + ln = &unixListener{unix, lnKey, &one} + unixSockets[lnKey] = ln.(*unixListener) + } + } + // lightly wrap the listener so that when it is closed, + // we can decrement the usage pool counter + if specificLn, ok := ln.(net.Listener); ok { + ln = deleteListener{specificLn, lnKey} + } } // other types, I guess we just return them directly @@ -170,12 +222,18 @@ type unixListener struct { func (uln *unixListener) Close() error { newCount := atomic.AddInt32(uln.count, -1) if newCount == 0 { + file, err := uln.File() + var name string + if err == nil { + name = file.Name() + } defer func() { - addr := uln.Addr().String() unixSocketsMu.Lock() delete(unixSockets, uln.mapKey) unixSocketsMu.Unlock() - _ = syscall.Unlink(addr) + if err == nil { + _ = syscall.Unlink(name) + } }() } return uln.UnixListener.Close() @@ -183,19 +241,25 @@ func (uln *unixListener) Close() error { type unixConn struct { *net.UnixConn - filename string - mapKey string - count *int32 // accessed atomically + mapKey string + count *int32 // accessed atomically } func (uc *unixConn) Close() error { newCount := atomic.AddInt32(uc.count, -1) if newCount == 0 { + file, err := uc.File() + var name string + if err == nil { + name = file.Name() + } defer func() { unixSocketsMu.Lock() delete(unixSockets, uc.mapKey) unixSocketsMu.Unlock() - _ = syscall.Unlink(uc.filename) + if err == nil { + _ = syscall.Unlink(name) + } }() } return uc.UnixConn.Close() @@ -211,6 +275,12 @@ var unixSockets = make(map[string]interface { File() (*os.File, error) }) +// socketFiles is a fd -> *os.File map used to make a FileListener/FilePacketConn from a socket file descriptor. +var socketFiles = map[uintptr]*os.File{} + +// socketFilesMu synchronizes socketFiles insertions +var socketFilesMu sync.Mutex + // deleteListener is a type that simply deletes itself // from the listenerPool when it closes. It is used // solely for the purpose of reference counting (i.e. diff --git a/listeners.go b/listeners.go index 0d9dd753..cf7b5201 100644 --- a/listeners.go +++ b/listeners.go @@ -58,7 +58,7 @@ type NetworkAddress struct { EndPort uint } -// ListenAll calls Listen() for all addresses represented by this struct, i.e. all ports in the range. +// ListenAll calls Listen for all addresses represented by this struct, i.e. all ports in the range. // (If the address doesn't use ports or has 1 port only, then only 1 listener will be created.) // It returns an error if any listener failed to bind, and closes any listeners opened up to that point. func (na NetworkAddress) ListenAll(ctx context.Context, config net.ListenConfig) ([]any, error) { @@ -106,7 +106,8 @@ func (na NetworkAddress) ListenAll(ctx context.Context, config net.ListenConfig) // portOffset to the start port. (For network types that do not use ports, the // portOffset is ignored.) // -// The provided ListenConfig is used to create the listener. Its Control function, +// First Listen checks if a plugin can provide a listener from this address. Otherwise, +// the provided ListenConfig is used to create the listener. Its Control function, // if set, may be wrapped by an internally-used Control function. The provided // context may be used to cancel long operations early. The context is not used // to close the listener after it has been created. @@ -129,6 +130,8 @@ func (na NetworkAddress) ListenAll(ctx context.Context, config net.ListenConfig) // Unix sockets will be unlinked before being created, to ensure we can bind to // it even if the previous program using it exited uncleanly; it will also be // unlinked upon a graceful exit (or when a new config does not use that socket). +// Listen synchronizes binds to unix domain sockets to avoid race conditions +// while an existing socket is unlinked. func (na NetworkAddress) Listen(ctx context.Context, portOffset uint, config net.ListenConfig) (any, error) { if na.IsUnixNetwork() { unixSocketsMu.Lock() @@ -146,54 +149,53 @@ func (na NetworkAddress) Listen(ctx context.Context, portOffset uint, config net func (na NetworkAddress) listen(ctx context.Context, portOffset uint, config net.ListenConfig) (any, error) { var ( - ln any - err error - address string - unixFileMode fs.FileMode - isAbstractUnixSocket bool + ln any + err error + address string + unixFileMode fs.FileMode ) // split unix socket addr early so lnKey // is independent of permissions bits if na.IsUnixNetwork() { - var err error address, unixFileMode, err = internal.SplitUnixSocketPermissionsBits(na.Host) if err != nil { return nil, err } - isAbstractUnixSocket = strings.HasPrefix(address, "@") + } else if na.IsFdNetwork() { + address = na.Host } else { address = na.JoinHostPort(portOffset) } - // if this is a unix socket, see if we already have it open, - // force socket permissions on it and return early - if socket, err := reuseUnixSocket(na.Network, address); socket != nil || err != nil { - if !isAbstractUnixSocket { - if err := os.Chmod(address, unixFileMode); err != nil { - return nil, fmt.Errorf("unable to set permissions (%s) on %s: %v", unixFileMode, address, err) - } - } - return socket, err - } - - lnKey := listenerKey(na.Network, address) - if strings.HasPrefix(na.Network, "ip") { ln, err = config.ListenPacket(ctx, na.Network, address) } else { - ln, err = listenReusable(ctx, lnKey, na.Network, address, config) - } - if err != nil { - return nil, err + if na.IsUnixNetwork() { + // if this is a unix socket, see if we already have it open + ln, err = reuseUnixSocket(na.Network, address) + } + + if ln == nil && err == nil { + // otherwise, create a new listener + lnKey := listenerKey(na.Network, address) + ln, err = listenReusable(ctx, lnKey, na.Network, address, config) + } } + if ln == nil { return nil, fmt.Errorf("unsupported network type: %s", na.Network) } + if err != nil { + return nil, err + } + if IsUnixNetwork(na.Network) { + isAbstractUnixSocket := strings.HasPrefix(address, "@") if !isAbstractUnixSocket { - if err := os.Chmod(address, unixFileMode); err != nil { + err = os.Chmod(address, unixFileMode) + if err != nil { return nil, fmt.Errorf("unable to set permissions (%s) on %s: %v", unixFileMode, address, err) } } @@ -208,13 +210,19 @@ func (na NetworkAddress) IsUnixNetwork() bool { return IsUnixNetwork(na.Network) } +// IsUnixNetwork returns true if na.Network is +// fd or fdgram. +func (na NetworkAddress) IsFdNetwork() bool { + return IsFdNetwork(na.Network) +} + // JoinHostPort is like net.JoinHostPort, but where the port // is StartPort + offset. func (na NetworkAddress) JoinHostPort(offset uint) string { - if na.IsUnixNetwork() { + if na.IsUnixNetwork() || na.IsFdNetwork() { return na.Host } - return net.JoinHostPort(na.Host, strconv.Itoa(int(na.StartPort+offset))) + return net.JoinHostPort(na.Host, strconv.FormatUint(uint64(na.StartPort+offset), 10)) } // Expand returns one NetworkAddress for each port in the port range. @@ -248,7 +256,7 @@ func (na NetworkAddress) PortRangeSize() uint { } func (na NetworkAddress) isLoopback() bool { - if na.IsUnixNetwork() { + if na.IsUnixNetwork() || na.IsFdNetwork() { return true } if na.Host == "localhost" { @@ -292,6 +300,30 @@ func IsUnixNetwork(netw string) bool { return strings.HasPrefix(netw, "unix") } +// IsFdNetwork returns true if the netw is a fd network. +func IsFdNetwork(netw string) bool { + return strings.HasPrefix(netw, "fd") +} + +// normally we would simply append the port, +// but if host is IPv6, we need to ensure it +// is enclosed in [ ]; net.JoinHostPort does +// this for us, but host might also have a +// network type in front (e.g. "tcp/") leading +// to "[tcp/::1]" which causes parsing failures +// later; what we need is "tcp/[::1]", so we have +// to split the network and host, then re-combine +func ParseNetworkAddressFromHostPort(host, port string) (NetworkAddress, error) { + network, addr, ok := strings.Cut(host, "/") + if !ok { + addr = network + network = "" + } + addr = strings.Trim(addr, "[]") // IPv6 + networkAddr := JoinNetworkAddress(network, addr, port) + return ParseNetworkAddress(networkAddr) +} + // ParseNetworkAddress parses addr into its individual // components. The input string is expected to be of // the form "network/host:port-range" where any part is @@ -322,6 +354,12 @@ func ParseNetworkAddressWithDefaults(addr, defaultNetwork string, defaultPort ui Host: host, }, err } + if IsFdNetwork(network) { + return NetworkAddress{ + Network: network, + Host: host, + }, nil + } var start, end uint64 if port == "" { start = uint64(defaultPort) @@ -362,7 +400,7 @@ func SplitNetworkAddress(a string) (network, host, port string, err error) { network = strings.ToLower(strings.TrimSpace(beforeSlash)) a = afterSlash } - if IsUnixNetwork(network) { + if IsUnixNetwork(network) || IsFdNetwork(network) { host = a return } @@ -393,7 +431,7 @@ func JoinNetworkAddress(network, host, port string) string { if network != "" { a = network + "/" } - if (host != "" && port == "") || IsUnixNetwork(network) { + if (host != "" && port == "") || IsUnixNetwork(network) || IsFdNetwork(network) { a += host } else if port != "" { a += net.JoinHostPort(host, port) @@ -401,9 +439,11 @@ func JoinNetworkAddress(network, host, port string) string { return a } -// ListenQUIC returns a quic.EarlyListener suitable for use in a Caddy module. -// The network will be transformed into a QUIC-compatible type (if unix, then -// unixgram will be used; otherwise, udp will be used). +// ListenQUIC returns a http3.QUICEarlyListener suitable for use in a Caddy module. +// +// The network will be transformed into a QUIC-compatible type if the same address can be used with +// different networks. Currently this just means that for tcp, udp will be used with the same +// address instead. // // NOTE: This API is EXPERIMENTAL and may be changed or removed. func (na NetworkAddress) ListenQUIC(ctx context.Context, portOffset uint, config net.ListenConfig, tlsConf *tls.Config) (http3.QUICEarlyListener, error) { @@ -617,7 +657,8 @@ func RegisterNetwork(network string, getListener ListenerFunc) { if network == "tcp" || network == "tcp4" || network == "tcp6" || network == "udp" || network == "udp4" || network == "udp6" || network == "unix" || network == "unixpacket" || network == "unixgram" || - strings.HasPrefix("ip:", network) || strings.HasPrefix("ip4:", network) || strings.HasPrefix("ip6:", network) { + strings.HasPrefix("ip:", network) || strings.HasPrefix("ip4:", network) || strings.HasPrefix("ip6:", network) || + network == "fd" || network == "fdgram" { panic("network type " + network + " is reserved") } diff --git a/modules/caddyhttp/app.go b/modules/caddyhttp/app.go index 5dbecf9b..673ebcb8 100644 --- a/modules/caddyhttp/app.go +++ b/modules/caddyhttp/app.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "fmt" + "maps" "net" "net/http" "strconv" @@ -203,17 +204,75 @@ func (app *App) Provision(ctx caddy.Context) error { } } - // the Go standard library does not let us serve only HTTP/2 using - // http.Server; we would probably need to write our own server - if !srv.protocol("h1") && (srv.protocol("h2") || srv.protocol("h2c")) { - return fmt.Errorf("server %s: cannot enable HTTP/2 or H2C without enabling HTTP/1.1; add h1 to protocols or remove h2/h2c", srvName) - } - // if no protocols configured explicitly, enable all except h2c if len(srv.Protocols) == 0 { srv.Protocols = []string{"h1", "h2", "h3"} } + srvProtocolsUnique := map[string]struct{}{} + for _, srvProtocol := range srv.Protocols { + srvProtocolsUnique[srvProtocol] = struct{}{} + } + _, h1ok := srvProtocolsUnique["h1"] + _, h2ok := srvProtocolsUnique["h2"] + _, h2cok := srvProtocolsUnique["h2c"] + + // the Go standard library does not let us serve only HTTP/2 using + // http.Server; we would probably need to write our own server + if !h1ok && (h2ok || h2cok) { + return fmt.Errorf("server %s: cannot enable HTTP/2 or H2C without enabling HTTP/1.1; add h1 to protocols or remove h2/h2c", srvName) + } + + if srv.ListenProtocols != nil { + if len(srv.ListenProtocols) != len(srv.Listen) { + return fmt.Errorf("server %s: listener protocols count does not match address count: %d != %d", + srvName, len(srv.ListenProtocols), len(srv.Listen)) + } + + for i, lnProtocols := range srv.ListenProtocols { + if lnProtocols != nil { + // populate empty listen protocols with server protocols + lnProtocolsDefault := false + var lnProtocolsInclude []string + srvProtocolsInclude := maps.Clone(srvProtocolsUnique) + + // keep existing listener protocols unless they are empty + for _, lnProtocol := range lnProtocols { + if lnProtocol == "" { + lnProtocolsDefault = true + } else { + lnProtocolsInclude = append(lnProtocolsInclude, lnProtocol) + delete(srvProtocolsInclude, lnProtocol) + } + } + + // append server protocols to listener protocols if any listener protocols were empty + if lnProtocolsDefault { + for _, srvProtocol := range srv.Protocols { + if _, ok := srvProtocolsInclude[srvProtocol]; ok { + lnProtocolsInclude = append(lnProtocolsInclude, srvProtocol) + } + } + } + + lnProtocolsIncludeUnique := map[string]struct{}{} + for _, lnProtocol := range lnProtocolsInclude { + lnProtocolsIncludeUnique[lnProtocol] = struct{}{} + } + _, h1ok := lnProtocolsIncludeUnique["h1"] + _, h2ok := lnProtocolsIncludeUnique["h2"] + _, h2cok := lnProtocolsIncludeUnique["h2c"] + + // check if any listener protocols contain h2 or h2c without h1 + if !h1ok && (h2ok || h2cok) { + return fmt.Errorf("server %s, listener %d: cannot enable HTTP/2 or H2C without enabling HTTP/1.1; add h1 to protocols or remove h2/h2c", srvName, i) + } + + srv.ListenProtocols[i] = lnProtocolsInclude + } + } + } + // if not explicitly configured by the user, disallow TLS // client auth bypass (domain fronting) which could // otherwise be exploited by sending an unprotected SNI @@ -344,7 +403,7 @@ func (app *App) Validate() error { // check that every address in the port range is unique to this server; // we do not use <= here because PortRangeSize() adds 1 to EndPort for us for i := uint(0); i < listenAddr.PortRangeSize(); i++ { - addr := caddy.JoinNetworkAddress(listenAddr.Network, listenAddr.Host, strconv.Itoa(int(listenAddr.StartPort+i))) + addr := caddy.JoinNetworkAddress(listenAddr.Network, listenAddr.Host, strconv.FormatUint(uint64(listenAddr.StartPort+i), 10)) if sn, ok := lnAddrs[addr]; ok { return fmt.Errorf("server %s: listener address repeated: %s (already claimed by server '%s')", srvName, addr, sn) } @@ -422,99 +481,118 @@ func (app *App) Start() error { srv.server.Handler = h2c.NewHandler(srv, h2server) } - for _, lnAddr := range srv.Listen { + for lnIndex, lnAddr := range srv.Listen { listenAddr, err := caddy.ParseNetworkAddress(lnAddr) if err != nil { return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err) } + srv.addresses = append(srv.addresses, listenAddr) - for portOffset := uint(0); portOffset < listenAddr.PortRangeSize(); portOffset++ { - // create the listener for this socket - hostport := listenAddr.JoinHostPort(portOffset) - lnAny, err := listenAddr.Listen(app.ctx, portOffset, net.ListenConfig{KeepAlive: time.Duration(srv.KeepAliveInterval)}) - if err != nil { - return fmt.Errorf("listening on %s: %v", listenAddr.At(portOffset), err) - } - ln := lnAny.(net.Listener) + protocols := srv.Protocols + if srv.ListenProtocols != nil && srv.ListenProtocols[lnIndex] != nil { + protocols = srv.ListenProtocols[lnIndex] + } - // wrap listener before TLS (up to the TLS placeholder wrapper) - var lnWrapperIdx int - for i, lnWrapper := range srv.listenerWrappers { - if _, ok := lnWrapper.(*tlsPlaceholderWrapper); ok { - lnWrapperIdx = i + 1 // mark the next wrapper's spot - break - } - ln = lnWrapper.WrapListener(ln) - } + protocolsUnique := map[string]struct{}{} + for _, protocol := range protocols { + protocolsUnique[protocol] = struct{}{} + } + _, h1ok := protocolsUnique["h1"] + _, h2ok := protocolsUnique["h2"] + _, h2cok := protocolsUnique["h2c"] + _, h3ok := protocolsUnique["h3"] + + for portOffset := uint(0); portOffset < listenAddr.PortRangeSize(); portOffset++ { + hostport := listenAddr.JoinHostPort(portOffset) // enable TLS if there is a policy and if this is not the HTTP port useTLS := len(srv.TLSConnPolicies) > 0 && int(listenAddr.StartPort+portOffset) != app.httpPort() - if useTLS { - // create TLS listener - this enables and terminates TLS - ln = tls.NewListener(ln, tlsCfg) - // enable HTTP/3 if configured - if srv.protocol("h3") { - // Can't serve HTTP/3 on the same socket as HTTP/1 and 2 because it uses - // a different transport mechanism... which is fine, but the OS doesn't - // differentiate between a SOCK_STREAM file and a SOCK_DGRAM file; they - // are still one file on the system. So even though "unixpacket" and - // "unixgram" are different network types just as "tcp" and "udp" are, - // the OS will not let us use the same file as both STREAM and DGRAM. - if len(srv.Protocols) > 1 && listenAddr.IsUnixNetwork() { - app.logger.Warn("HTTP/3 disabled because Unix can't multiplex STREAM and DGRAM on same socket", - zap.String("file", hostport)) - for i := range srv.Protocols { - if srv.Protocols[i] == "h3" { - srv.Protocols = append(srv.Protocols[:i], srv.Protocols[i+1:]...) - break - } - } - } else { - app.logger.Info("enabling HTTP/3 listener", zap.String("addr", hostport)) - if err := srv.serveHTTP3(listenAddr.At(portOffset), tlsCfg); err != nil { - return err - } + // enable HTTP/3 if configured + if h3ok && useTLS { + app.logger.Info("enabling HTTP/3 listener", zap.String("addr", hostport)) + if err := srv.serveHTTP3(listenAddr.At(portOffset), tlsCfg); err != nil { + return err + } + } + + if h3ok && !useTLS { + // Can only serve h3 with TLS enabled + app.logger.Warn("HTTP/3 skipped because it requires TLS", + zap.String("network", listenAddr.Network), + zap.String("addr", hostport)) + } + + if h1ok || h2ok && useTLS || h2cok { + // create the listener for this socket + lnAny, err := listenAddr.Listen(app.ctx, portOffset, net.ListenConfig{KeepAlive: time.Duration(srv.KeepAliveInterval)}) + if err != nil { + return fmt.Errorf("listening on %s: %v", listenAddr.At(portOffset), err) + } + ln, ok := lnAny.(net.Listener) + if !ok { + return fmt.Errorf("network '%s' cannot handle HTTP/1 or HTTP/2 connections", listenAddr.Network) + } + + if useTLS { + // create TLS listener - this enables and terminates TLS + ln = tls.NewListener(ln, tlsCfg) + } + + // wrap listener before TLS (up to the TLS placeholder wrapper) + var lnWrapperIdx int + for i, lnWrapper := range srv.listenerWrappers { + if _, ok := lnWrapper.(*tlsPlaceholderWrapper); ok { + lnWrapperIdx = i + 1 // mark the next wrapper's spot + break } + ln = lnWrapper.WrapListener(ln) + } + + // finish wrapping listener where we left off before TLS + for i := lnWrapperIdx; i < len(srv.listenerWrappers); i++ { + ln = srv.listenerWrappers[i].WrapListener(ln) + } + + // handle http2 if use tls listener wrapper + if h2ok { + http2lnWrapper := &http2Listener{ + Listener: ln, + server: srv.server, + h2server: h2server, + } + srv.h2listeners = append(srv.h2listeners, http2lnWrapper) + ln = http2lnWrapper + } + + // if binding to port 0, the OS chooses a port for us; + // but the user won't know the port unless we print it + if !listenAddr.IsUnixNetwork() && !listenAddr.IsFdNetwork() && listenAddr.StartPort == 0 && listenAddr.EndPort == 0 { + app.logger.Info("port 0 listener", + zap.String("input_address", lnAddr), + zap.String("actual_address", ln.Addr().String())) + } + + app.logger.Debug("starting server loop", + zap.String("address", ln.Addr().String()), + zap.Bool("tls", useTLS), + zap.Bool("http3", srv.h3server != nil)) + + srv.listeners = append(srv.listeners, ln) + + // enable HTTP/1 if configured + if h1ok { + //nolint:errcheck + go srv.server.Serve(ln) } } - // finish wrapping listener where we left off before TLS - for i := lnWrapperIdx; i < len(srv.listenerWrappers); i++ { - ln = srv.listenerWrappers[i].WrapListener(ln) - } - - // handle http2 if use tls listener wrapper - if useTLS { - http2lnWrapper := &http2Listener{ - Listener: ln, - server: srv.server, - h2server: h2server, - } - srv.h2listeners = append(srv.h2listeners, http2lnWrapper) - ln = http2lnWrapper - } - - // if binding to port 0, the OS chooses a port for us; - // but the user won't know the port unless we print it - if !listenAddr.IsUnixNetwork() && listenAddr.StartPort == 0 && listenAddr.EndPort == 0 { - app.logger.Info("port 0 listener", - zap.String("input_address", lnAddr), - zap.String("actual_address", ln.Addr().String())) - } - - app.logger.Debug("starting server loop", - zap.String("address", ln.Addr().String()), - zap.Bool("tls", useTLS), - zap.Bool("http3", srv.h3server != nil)) - - srv.listeners = append(srv.listeners, ln) - - // enable HTTP/1 if configured - if srv.protocol("h1") { - //nolint:errcheck - go srv.server.Serve(ln) + if h2ok && !useTLS { + // Can only serve h2 with TLS enabled + app.logger.Warn("HTTP/2 skipped because it requires TLS", + zap.String("network", listenAddr.Network), + zap.String("addr", hostport)) } } } diff --git a/modules/caddyhttp/proxyprotocol/listenerwrapper.go b/modules/caddyhttp/proxyprotocol/listenerwrapper.go index 440e7071..f5f2099c 100644 --- a/modules/caddyhttp/proxyprotocol/listenerwrapper.go +++ b/modules/caddyhttp/proxyprotocol/listenerwrapper.go @@ -72,7 +72,7 @@ func (pp *ListenerWrapper) Provision(ctx caddy.Context) error { pp.policy = func(options goproxy.ConnPolicyOptions) (goproxy.Policy, error) { // trust unix sockets - if network := options.Upstream.Network(); caddy.IsUnixNetwork(network) { + if network := options.Upstream.Network(); caddy.IsUnixNetwork(network) || caddy.IsFdNetwork(network) { return goproxy.USE, nil } ret := pp.FallbackPolicy diff --git a/modules/caddyhttp/reverseproxy/addresses.go b/modules/caddyhttp/reverseproxy/addresses.go index 82c1c799..31f4aeb3 100644 --- a/modules/caddyhttp/reverseproxy/addresses.go +++ b/modules/caddyhttp/reverseproxy/addresses.go @@ -137,7 +137,7 @@ func parseUpstreamDialAddress(upstreamAddr string) (parsedAddr, error) { } // we can assume a port if only a hostname is specified, but use of a // placeholder without a port likely means a port will be filled in - if port == "" && !strings.Contains(host, "{") && !caddy.IsUnixNetwork(network) { + if port == "" && !strings.Contains(host, "{") && !caddy.IsUnixNetwork(network) && !caddy.IsFdNetwork(network) { port = "80" } } diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go index 319cc924..1735e45a 100644 --- a/modules/caddyhttp/reverseproxy/healthchecks.go +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -330,7 +330,7 @@ func (h *Handler) doActiveHealthCheckForAllHosts() { return } if hcp := uint(upstream.activeHealthCheckPort); hcp != 0 { - if addr.IsUnixNetwork() { + if addr.IsUnixNetwork() || addr.IsFdNetwork() { addr.Network = "tcp" // I guess we just assume TCP since we are using a port?? } addr.StartPort, addr.EndPort = hcp, hcp @@ -345,7 +345,7 @@ func (h *Handler) doActiveHealthCheckForAllHosts() { } hostAddr := addr.JoinHostPort(0) dialAddr := hostAddr - if addr.IsUnixNetwork() { + if addr.IsUnixNetwork() || addr.IsFdNetwork() { // this will be used as the Host portion of a http.Request URL, and // paths to socket files would produce an error when creating URL, // so use a fake Host value instead; unix sockets are usually local diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index f5478cb3..5aa7e0f6 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -220,6 +220,10 @@ type Server struct { // Default: `[h1 h2 h3]` Protocols []string `json:"protocols,omitempty"` + // ListenProtocols overrides Protocols for each parallel address in Listen. + // A nil value or element indicates that Protocols will be used instead. + ListenProtocols [][]string `json:"listen_protocols,omitempty"` + // If set, metrics observations will be enabled. // This setting is EXPERIMENTAL and subject to change. Metrics *Metrics `json:"metrics,omitempty"` @@ -597,7 +601,11 @@ func (s *Server) findLastRouteWithHostMatcher() int { // not already done, and then uses that server to serve HTTP/3 over // the listener, with Server s as the handler. func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error { - addr.Network = getHTTP3Network(addr.Network) + h3net, err := getHTTP3Network(addr.Network) + if err != nil { + return fmt.Errorf("starting HTTP/3 QUIC listener: %v", err) + } + addr.Network = h3net h3ln, err := addr.ListenQUIC(s.ctx, 0, net.ListenConfig{}, tlsCfg) if err != nil { return fmt.Errorf("starting HTTP/3 QUIC listener: %v", err) @@ -849,7 +857,21 @@ func (s *Server) logRequest( // protocol returns true if the protocol proto is configured/enabled. func (s *Server) protocol(proto string) bool { - return slices.Contains(s.Protocols, proto) + if s.ListenProtocols == nil { + if slices.Contains(s.Protocols, proto) { + return true + } + } else { + for _, lnProtocols := range s.ListenProtocols { + for _, lnProtocol := range lnProtocols { + if lnProtocol == "" && slices.Contains(s.Protocols, proto) || lnProtocol == proto { + return true + } + } + } + } + + return false } // Listeners returns the server's listeners. These are active listeners, @@ -1089,9 +1111,14 @@ const ( ) var networkTypesHTTP3 = map[string]string{ - "unix": "unixgram", - "tcp4": "udp4", - "tcp6": "udp6", + "unixgram": "unixgram", + "udp": "udp", + "udp4": "udp4", + "udp6": "udp6", + "tcp": "udp", + "tcp4": "udp4", + "tcp6": "udp6", + "fdgram": "fdgram", } // RegisterNetworkHTTP3 registers a mapping from non-HTTP/3 network to HTTP/3 @@ -1106,11 +1133,10 @@ func RegisterNetworkHTTP3(originalNetwork, h3Network string) { networkTypesHTTP3[originalNetwork] = h3Network } -func getHTTP3Network(originalNetwork string) string { +func getHTTP3Network(originalNetwork string) (string, error) { h3Network, ok := networkTypesHTTP3[strings.ToLower(originalNetwork)] if !ok { - // TODO: Maybe a better default is to not enable HTTP/3 if we do not know the network? - return "udp" + return "", fmt.Errorf("network '%s' cannot handle HTTP/3 connections", originalNetwork) } - return h3Network + return h3Network, nil } diff --git a/modules/caddyhttp/staticresp.go b/modules/caddyhttp/staticresp.go index 1fea6978..1b93ede4 100644 --- a/modules/caddyhttp/staticresp.go +++ b/modules/caddyhttp/staticresp.go @@ -387,7 +387,7 @@ func cmdRespond(fl caddycmd.Flags) (int, error) { return caddy.ExitCodeFailedStartup, err } - if !listenAddr.IsUnixNetwork() { + if !listenAddr.IsUnixNetwork() && !listenAddr.IsFdNetwork() { listenAddrs := make([]string, 0, listenAddr.PortRangeSize()) for offset := uint(0); offset < listenAddr.PortRangeSize(); offset++ { listenAddrs = append(listenAddrs, listenAddr.JoinHostPort(offset)) diff --git a/replacer.go b/replacer.go index 65815c92..297dd935 100644 --- a/replacer.go +++ b/replacer.go @@ -299,11 +299,11 @@ func ToString(val any) string { case int64: return strconv.Itoa(int(v)) case uint: - return strconv.Itoa(int(v)) + return strconv.FormatUint(uint64(v), 10) case uint32: - return strconv.Itoa(int(v)) + return strconv.FormatUint(uint64(v), 10) case uint64: - return strconv.Itoa(int(v)) + return strconv.FormatUint(v, 10) case float32: return strconv.FormatFloat(float64(v), 'f', -1, 32) case float64: