package capnp import ( "encoding/binary" "errors" ) // A SegmentID is a numeric identifier for a Segment. type SegmentID uint32 // A Segment is an allocation arena for Cap'n Proto objects. // It is part of a Message, which can contain other segments that // reference each other. type Segment struct { msg *Message id SegmentID data []byte } // Message returns the message that contains s. func (s *Segment) Message() *Message { return s.msg } // ID returns the segment's ID. func (s *Segment) ID() SegmentID { return s.id } // Data returns the raw byte slice for the segment. func (s *Segment) Data() []byte { return s.data } func (s *Segment) inBounds(addr Address) bool { return addr < Address(len(s.data)) } func (s *Segment) regionInBounds(base Address, sz Size) bool { end, ok := base.addSize(sz) if !ok { return false } return end <= Address(len(s.data)) } // slice returns the segment of data from base to base+sz. func (s *Segment) slice(base Address, sz Size) []byte { // Bounds check should have happened before calling slice. return s.data[base : base+Address(sz)] } func (s *Segment) readUint8(addr Address) uint8 { return s.slice(addr, 1)[0] } func (s *Segment) readUint16(addr Address) uint16 { return binary.LittleEndian.Uint16(s.slice(addr, 2)) } func (s *Segment) readUint32(addr Address) uint32 { return binary.LittleEndian.Uint32(s.slice(addr, 4)) } func (s *Segment) readUint64(addr Address) uint64 { return binary.LittleEndian.Uint64(s.slice(addr, 8)) } func (s *Segment) readRawPointer(addr Address) rawPointer { return rawPointer(s.readUint64(addr)) } func (s *Segment) writeUint8(addr Address, val uint8) { s.slice(addr, 1)[0] = val } func (s *Segment) writeUint16(addr Address, val uint16) { binary.LittleEndian.PutUint16(s.slice(addr, 2), val) } func (s *Segment) writeUint32(addr Address, val uint32) { binary.LittleEndian.PutUint32(s.slice(addr, 4), val) } func (s *Segment) writeUint64(addr Address, val uint64) { binary.LittleEndian.PutUint64(s.slice(addr, 8), val) } func (s *Segment) writeRawPointer(addr Address, val rawPointer) { s.writeUint64(addr, uint64(val)) } // root returns a 1-element pointer list that references the first word // in the segment. This only makes sense to call on the first segment // in a message. func (s *Segment) root() PointerList { sz := ObjectSize{PointerCount: 1} if !s.regionInBounds(0, sz.totalSize()) { return PointerList{} } return PointerList{List{ seg: s, length: 1, size: sz, depthLimit: s.msg.depthLimit(), }} } func (s *Segment) lookupSegment(id SegmentID) (*Segment, error) { if s.id == id { return s, nil } return s.msg.Segment(id) } func (s *Segment) readPtr(paddr Address, depthLimit uint) (ptr Ptr, err error) { s, base, val, err := s.resolveFarPointer(paddr) if err != nil { return Ptr{}, err } if val == 0 { return Ptr{}, nil } if depthLimit == 0 { return Ptr{}, errDepthLimit } switch val.pointerType() { case structPointer: sp, err := s.readStructPtr(base, val) if err != nil { return Ptr{}, err } if !s.msg.ReadLimiter().canRead(sp.readSize()) { return Ptr{}, errReadLimit } sp.depthLimit = depthLimit - 1 return sp.ToPtr(), nil case listPointer: lp, err := s.readListPtr(base, val) if err != nil { return Ptr{}, err } if !s.msg.ReadLimiter().canRead(lp.readSize()) { return Ptr{}, errReadLimit } lp.depthLimit = depthLimit - 1 return lp.ToPtr(), nil case otherPointer: if val.otherPointerType() != 0 { return Ptr{}, errOtherPointer } return Interface{ seg: s, cap: val.capabilityIndex(), }.ToPtr(), nil default: // Only other types are far pointers. return Ptr{}, errBadLandingPad } } func (s *Segment) readStructPtr(base Address, val rawPointer) (Struct, error) { addr, ok := val.offset().resolve(base) if !ok { return Struct{}, errPointerAddress } sz := val.structSize() if !s.regionInBounds(addr, sz.totalSize()) { return Struct{}, errPointerAddress } return Struct{ seg: s, off: addr, size: sz, }, nil } func (s *Segment) readListPtr(base Address, val rawPointer) (List, error) { addr, ok := val.offset().resolve(base) if !ok { return List{}, errPointerAddress } lsize, ok := val.totalListSize() if !ok { return List{}, errOverflow } if !s.regionInBounds(addr, lsize) { return List{}, errPointerAddress } lt := val.listType() if lt == compositeList { hdr := s.readRawPointer(addr) var ok bool addr, ok = addr.addSize(wordSize) if !ok { return List{}, errOverflow } if hdr.pointerType() != structPointer { return List{}, errBadTag } sz := hdr.structSize() n := int32(hdr.offset()) if n < 0 { return List{}, errListSize } // TODO(light): check that this has the same end address if tsize, ok := sz.totalSize().times(n); !ok { return List{}, errOverflow } else if !s.regionInBounds(addr, tsize) { return List{}, errPointerAddress } return List{ seg: s, size: sz, off: addr, length: n, flags: isCompositeList, }, nil } n := val.numListElements() if n < 0 { return List{}, errListSize } if lt == bit1List { return List{ seg: s, off: addr, length: n, flags: isBitList, }, nil } return List{ seg: s, size: val.elementSize(), off: addr, length: n, }, nil } func (s *Segment) resolveFarPointer(paddr Address) (dst *Segment, base Address, resolved rawPointer, err error) { // Encoding details at https://capnproto.org/encoding.html#inter-segment-pointers val := s.readRawPointer(paddr) switch val.pointerType() { case doubleFarPointer: padSeg, err := s.lookupSegment(val.farSegment()) if err != nil { return nil, 0, 0, err } padAddr := val.farAddress() if !padSeg.regionInBounds(padAddr, wordSize*2) { return nil, 0, 0, errPointerAddress } far := padSeg.readRawPointer(padAddr) if far.pointerType() != farPointer { return nil, 0, 0, errBadLandingPad } tagAddr, ok := padAddr.addSize(wordSize) if !ok { return nil, 0, 0, errOverflow } tag := padSeg.readRawPointer(tagAddr) if pt := tag.pointerType(); (pt != structPointer && pt != listPointer) || tag.offset() != 0 { return nil, 0, 0, errBadLandingPad } if dst, err = s.lookupSegment(far.farSegment()); err != nil { return nil, 0, 0, err } return dst, 0, landingPadNearPointer(far, tag), nil case farPointer: var err error dst, err = s.lookupSegment(val.farSegment()) if err != nil { return nil, 0, 0, err } padAddr := val.farAddress() if !dst.regionInBounds(padAddr, wordSize) { return nil, 0, 0, errPointerAddress } var ok bool base, ok = padAddr.addSize(wordSize) if !ok { return nil, 0, 0, errOverflow } return dst, base, dst.readRawPointer(padAddr), nil default: var ok bool base, ok = paddr.addSize(wordSize) if !ok { return nil, 0, 0, errOverflow } return s, base, val, nil } } func (s *Segment) writePtr(off Address, src Ptr, forceCopy bool) error { if !src.IsValid() { s.writeRawPointer(off, 0) return nil } // Copy src, if needed, and process pointers where placement is // irrelevant (capabilities and zero-sized structs). var srcAddr Address var srcRaw rawPointer switch src.flags.ptrType() { case structPtrType: st := src.Struct() if st.size.isZero() { // Zero-sized structs should always be encoded with offset -1 in // order to avoid conflating with null. No allocation needed. s.writeRawPointer(off, rawStructPointer(-1, ObjectSize{})) return nil } if forceCopy || src.seg.msg != s.msg || st.flags&isListMember != 0 { newSeg, newAddr, err := alloc(s, st.size.totalSize()) if err != nil { return err } dst := Struct{ seg: newSeg, off: newAddr, size: st.size, depthLimit: maxDepth, // clear flags } if err := copyStruct(dst, st); err != nil { return err } st = dst src = dst.ToPtr() } srcAddr = st.off srcRaw = rawStructPointer(0, st.size) case listPtrType: l := src.List() if forceCopy || src.seg.msg != s.msg { sz := l.allocSize() newSeg, newAddr, err := alloc(s, sz) if err != nil { return err } dst := List{ seg: newSeg, off: newAddr, length: l.length, size: l.size, flags: l.flags, depthLimit: maxDepth, } if dst.flags&isCompositeList != 0 { // Copy tag word newSeg.writeRawPointer(newAddr, l.seg.readRawPointer(l.off-Address(wordSize))) var ok bool dst.off, ok = dst.off.addSize(wordSize) if !ok { return errOverflow } sz -= wordSize } if dst.flags&isBitList != 0 || dst.size.PointerCount == 0 { end, _ := l.off.addSize(sz) // list was already validated copy(newSeg.data[dst.off:], l.seg.data[l.off:end]) } else { for i := 0; i < l.Len(); i++ { err := copyStruct(dst.Struct(i), l.Struct(i)) if err != nil { return err } } } l = dst src = dst.ToPtr() } srcAddr = l.off if l.flags&isCompositeList != 0 { srcAddr -= Address(wordSize) } srcRaw = l.raw() case interfacePtrType: i := src.Interface() if src.seg.msg != s.msg { c := s.msg.AddCap(i.Client()) i = NewInterface(s, c) } s.writeRawPointer(off, i.value(off)) return nil default: panic("unreachable") } switch { case src.seg == s: // Common case: src is in same segment as pointer. // Use a near pointer. s.writeRawPointer(off, srcRaw.withOffset(nearPointerOffset(off, srcAddr))) return nil case hasCapacity(src.seg.data, wordSize): // Enough room adjacent to src to write a far pointer landing pad. _, padAddr, _ := alloc(src.seg, wordSize) src.seg.writeRawPointer(padAddr, srcRaw.withOffset(nearPointerOffset(padAddr, srcAddr))) s.writeRawPointer(off, rawFarPointer(src.seg.id, padAddr)) return nil default: // Not enough room for a landing pad, need to use a double-far pointer. padSeg, padAddr, err := alloc(s, wordSize*2) if err != nil { return err } padSeg.writeRawPointer(padAddr, rawFarPointer(src.seg.id, srcAddr)) padSeg.writeRawPointer(padAddr+Address(wordSize), srcRaw) s.writeRawPointer(off, rawDoubleFarPointer(padSeg.id, padAddr)) return nil } } var ( errPointerAddress = errors.New("capnp: invalid pointer address") errBadLandingPad = errors.New("capnp: invalid far pointer landing pad") errBadTag = errors.New("capnp: invalid tag word") errOtherPointer = errors.New("capnp: unknown pointer type") errObjectSize = errors.New("capnp: invalid object size") errElementSize = errors.New("capnp: mismatched list element size") errReadLimit = errors.New("capnp: read traversal limit reached") errDepthLimit = errors.New("capnp: depth limit reached") ) var ( errOverflow = errors.New("capnp: address or size overflow") errOutOfBounds = errors.New("capnp: address out of bounds") errCopyDepth = errors.New("capnp: copy depth too large") errOverlap = errors.New("capnp: overlapping data on copy") errListSize = errors.New("capnp: invalid list size") )